Spaces:
Running
Running
| """ | |
| Research -> Interactive Explainer Environment (multi-step, async). | |
| Episode flow: | |
| 1. reset() → agent gets a topic + tier | |
| 2. step(explore) × 0..MAX_EXPLORE → agent calls research tools | |
| 3. step(generate) × 1 → agent produces marimo/manim code | |
| 4. step(repair) × 0..MAX_REPAIR → agent fixes lint/build errors if needed | |
| Each step returns a per-step reward. The final generate step also includes | |
| a generation reward that accounts for how well the code uses the research. | |
| The environment supports async via reset_async() / step_async() overrides. | |
| OpenEnv's HTTP server detects these and calls them directly (no thread pool). | |
| """ | |
| import random | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| try: | |
| from ..constants import MAX_EXPLORE_STEPS, MAX_REPAIR_STEPS, clamp_action_reward | |
| from ..models import ExplainerAction, ExplainerObservation | |
| from ..research import AVAILABLE_TOOLS, run_research_tool | |
| from ..rewards.exploration import compute_explore_reward | |
| from ..rewards.generation import adjust_repair_reward, compute_generate_reward | |
| from ..rewards.sandbox import validate_code | |
| from ..task_bank import ALL_TASKS, EASY_TASKS, HARD_TASKS, MEDIUM_TASKS, Task | |
| except ImportError: | |
| from constants import MAX_EXPLORE_STEPS, MAX_REPAIR_STEPS, clamp_action_reward | |
| from models import ExplainerAction, ExplainerObservation | |
| from research import AVAILABLE_TOOLS, run_research_tool | |
| from rewards.exploration import compute_explore_reward | |
| from rewards.generation import adjust_repair_reward, compute_generate_reward | |
| from rewards.sandbox import validate_code | |
| from task_bank import ALL_TASKS, EASY_TASKS, HARD_TASKS, MEDIUM_TASKS, Task | |
| MB002_REPAIR_HINT = ( | |
| "MB002 repair checklist: Marimo treats every non-underscore assignment as a " | |
| "global notebook variable, including `for` loop variables. Audit the whole " | |
| "file and rename cell-local names to private names everywhere: `arr` -> " | |
| "`_arr`, `target` -> `_target`, `i` -> `_i`, `t` -> `_t`, `freqs` -> " | |
| "`_freqs`, `fig` -> `_fig`, `ax` -> `_ax`. Public names should only be used " | |
| "for values intentionally passed to later cells, and each public name may be " | |
| "defined once globally." | |
| ) | |
| def _render_errors_with_hints(errors: str, error_codes: list[str]) -> str: | |
| if "MB002" not in error_codes: | |
| return errors | |
| return f"{errors}\n\n{MB002_REPAIR_HINT}" | |
class ExplainerEnvironment(Environment):
    """
    Multi-step Research → Interactive Explainer environment.

    Phase 1 (explore): agent issues search queries, receives papers/wiki sections.
    Phase 2 (generate): agent produces marimo/manim code using the research.
    Phase 3 (repair): entered only when the generated code fails validation;
    the agent may resubmit up to MAX_REPAIR_STEPS corrected artifacts.

    Supports async via reset_async() / step_async() — OpenEnv's server detects
    the overrides and awaits them directly instead of using a thread pool.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        super().__init__()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._current_task: Task | None = None
        # Pool used by reset() when no (recognised) difficulty is requested.
        self._difficulty_pool: list[Task] = EASY_TASKS
        self._clear_episode_state()

    # ------------------------------------------------------------------
    # Sync interface (fallback — OpenEnv prefers async when overridden)
    # ------------------------------------------------------------------
    def reset(self, seed=None, episode_id=None, **kwargs) -> ExplainerObservation:
        """Sample a task and return the initial observation (sync).

        Accepts the same ``topic`` / ``difficulty`` keyword arguments as
        ``_do_reset``.
        """
        return self._do_reset(seed=seed, episode_id=episode_id, **kwargs)

    def step(self, action: ExplainerAction, timeout_s=None, **kwargs) -> ExplainerObservation:
        """Route one action to its phase handler (sync).

        Explore actions are coroutine-based, so they are driven to completion
        on a blocking event loop here. ``timeout_s`` and ``kwargs`` are
        accepted for interface compatibility and are not used.

        Returns:
            The observation for this step. Any exception raised while handling
            the action is converted into a terminal "Environment error"
            observation instead of propagating.
        """
        self._state.step_count += 1
        task = self._current_task
        early = self._precheck(task)
        if early is not None:
            return early
        try:
            if action.action_type == "explore":
                return self._run_coroutine_blocking(self._handle_explore(action, task))
            return self._handle_non_explore(action, task)
        except Exception as e:
            return self._terminate(task, phase="done", feedback=f"Environment error: {e}")

    # ------------------------------------------------------------------
    # Async interface (preferred — OpenEnv detects these overrides)
    # ------------------------------------------------------------------
    async def reset_async(self, seed=None, episode_id=None, **kwargs) -> ExplainerObservation:
        """Sample a task and return the initial observation (async)."""
        return self._do_reset(seed=seed, episode_id=episode_id, **kwargs)

    async def step_async(self, action: ExplainerAction, timeout_s=None, **kwargs) -> ExplainerObservation:
        """Route one action to its phase handler (async).

        Same contract as ``step``; the explore handler is awaited directly.
        """
        self._state.step_count += 1
        task = self._current_task
        early = self._precheck(task)
        if early is not None:
            return early
        try:
            if action.action_type == "explore":
                return await self._handle_explore(action, task)
            return self._handle_non_explore(action, task)
        except Exception as e:
            return self._terminate(task, phase="done", feedback=f"Environment error: {e}")

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------
    def _clear_episode_state(self) -> None:
        """Reset every per-episode field to its start-of-episode value."""
        self._accumulated_context: list[str] = []
        self._explore_actions: list[str] = []
        self._used_tools: set[str] = set()
        self._explore_steps: int = 0
        self._repair_steps: int = 0
        self._phase: str = "explore"
        self._done: bool = False
        self._last_code: str = ""
        self._last_format: str = "marimo"
        self._last_narration: str = ""
        self._last_errors: str = ""
        self._last_error_codes: list[str] = []

    @staticmethod
    def _run_coroutine_blocking(coro):
        """Run ``coro`` to completion from synchronous code and return its result.

        ``asyncio.run`` refuses to start when the calling thread already has a
        running event loop (for example inside a notebook). In that case the
        coroutine is run on a fresh loop in a short-lived worker thread while
        this thread blocks on the result.
        """
        import asyncio
        from concurrent.futures import ThreadPoolExecutor

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: the straightforward path.
            return asyncio.run(coro)
        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    def _precheck(self, task: Task | None) -> ExplainerObservation | None:
        """Return a terminal observation if stepping is not allowed, else None."""
        if task is None:
            return ExplainerObservation(
                feedback="Error: no task set. Call reset() first.",
                done=True,
                reward=-1.0,
            )
        if self._done:
            return self._make_obs(
                task,
                phase="done",
                feedback="Episode is already done. Call reset() to start a new one.",
                reward=0.0,
                done=True,
            )
        return None

    def _terminate(self, task: Task, *, phase: str, feedback: str) -> ExplainerObservation:
        """End the episode with zero reward and return the final observation.

        The internal phase/done flags are updated as well, so that a caller
        that keeps stepping after this ``done=True`` observation gets the
        "already done" response rather than silently resuming the episode.
        ``phase`` is only the value reported in the observation.
        """
        self._phase = "done"
        self._done = True
        return self._make_obs(task, phase=phase, feedback=feedback, reward=0.0, done=True)

    def _handle_non_explore(self, action: ExplainerAction, task: Task) -> ExplainerObservation:
        """Dispatch generate/repair actions; any other action type ends the episode."""
        kind = action.action_type
        if kind == "generate":
            return self._handle_generate(action, task)
        if kind == "repair":
            return self._handle_repair(action, task)
        # The reported phase stays "explore" for compatibility with existing clients.
        return self._terminate(
            task,
            phase="explore",
            feedback=f"Unknown action_type: {kind}",
        )

    def _do_reset(self, seed=None, episode_id=None, **kwargs) -> ExplainerObservation:
        """Shared reset logic (no I/O, so sync is fine).

        Args:
            seed: Seed for the task-sampling RNG; ``None`` means unseeded.
            episode_id: Episode identifier; a fresh UUID is used when falsy.
            **kwargs: ``topic`` selects the task whose topic matches exactly
                (falling back to a random task from ALL_TASKS when there is no
                match). Otherwise ``difficulty`` ("easy" / "medium" / "hard")
                selects the sampling pool; any other value uses the
                environment's default pool.

        Returns:
            The initial observation for the sampled task.
        """
        self._state = State(
            episode_id=episode_id or str(uuid4()), step_count=0
        )
        self._clear_episode_state()

        chosen: Task | None = None
        topic = kwargs.get("topic", None)
        if topic:
            # Allow selecting a specific task by topic name.
            chosen = next((t for t in ALL_TASKS if t.topic == topic), None)
            pool = ALL_TASKS
        else:
            difficulty = kwargs.get("difficulty", None)
            if difficulty == "medium":
                pool = MEDIUM_TASKS
            elif difficulty == "hard":
                pool = HARD_TASKS
            elif difficulty == "easy":
                pool = EASY_TASKS
            else:
                pool = self._difficulty_pool
            # An empty pool falls back to the full task list.
            pool = pool if pool else ALL_TASKS
        if chosen is None:
            # random.Random(None) seeds from system entropy, so no special
            # casing of a missing seed is needed.
            chosen = random.Random(seed).choice(pool)
        self._current_task = chosen

        t = chosen
        return ExplainerObservation(
            topic=t.topic,
            content=t.content,
            tier=t.tier,
            keywords=t.keywords,
            data_available=t.data_available,
            difficulty=t.difficulty,
            phase="explore",
            feedback=(
                "Research phase: choose a tool and query relevant to the topic. "
                f"Available tools: {', '.join(AVAILABLE_TOOLS)}."
            ),
            search_results="",
            explored_context="",
            explore_steps_left=MAX_EXPLORE_STEPS,
            repair_attempts_left=MAX_REPAIR_STEPS,
            available_tools=list(AVAILABLE_TOOLS),
            done=False,
            reward=0.0,
        )

    async def _handle_explore(self, action: ExplainerAction, task: Task) -> ExplainerObservation:
        """Process an explore action: call a research tool and score the result."""
        if self._phase not in {"explore", "generate"}:
            return self._make_obs(
                task,
                phase=self._phase,
                feedback=f"Cannot explore during phase '{self._phase}'.",
                reward=0.0,
            )
        if self._explore_steps >= MAX_EXPLORE_STEPS:
            self._phase = "generate"
            return self._make_obs(
                task,
                phase="generate",
                feedback="Max explore steps reached. You must now generate.",
                reward=0.0,
            )

        # The step is counted before validation, so an empty query still
        # consumes one unit of the explore budget.
        self._explore_steps += 1
        # A missing (None) query/intent is treated as empty rather than
        # crashing the step.
        query = (action.query or "").strip()
        intent = (action.intent or "").strip()
        tool = action.tool or "search_wikipedia"
        if not query:
            return self._make_obs(
                task,
                phase="explore",
                feedback="Empty query. Provide a search query.",
                reward=0.0,
            )

        # Snapshot the history *before* this call; the reward function gets
        # both the prior state and the updated accumulated context.
        previous_context = list(self._accumulated_context)
        previous_actions = list(self._explore_actions)
        used_tools = set(self._used_tools)

        result = await run_research_tool(tool, query, intent)
        results_text = result.render()
        self._explore_actions.append(_explore_action_text(tool, query, intent))
        if result.ok:
            self._accumulated_context.append(result.text)
        self._used_tools.add(tool)

        # Compute per-step exploration reward
        reward, components = compute_explore_reward(
            query=query,
            tool=tool,
            intent=intent,
            result=result,
            topic=task.topic,
            keywords_csv=task.keywords,
            task_content=task.content,
            difficulty=task.difficulty,
            previous_context=previous_context,
            accumulated_context=self._accumulated_context,
            used_tools=used_tools,
            previous_actions=previous_actions,
        )

        steps_left = MAX_EXPLORE_STEPS - self._explore_steps
        if steps_left > 1:
            phase = "explore"
            hint = f"Research going well — {steps_left} more steps available. Keep searching or move to generation."
        elif steps_left == 1:
            phase = "explore"
            hint = "Last research step available. Search for any missing context, or proceed to generate."
        else:
            phase = "generate"
            hint = "Research phase complete. Time to generate your explanation."
        self._phase = phase

        top_chunks = _top_chunks_payload(result.chunks)
        return self._make_obs(
            task,
            phase=phase,
            feedback=f"{hint}\nTool: {tool}\nReward: {components}",
            search_results=results_text,
            top_chunks=top_chunks,
            reward=reward,
            metadata={
                "step": self._state.step_count,
                "phase": "explore",
                "tool": tool,
                "source_count": len(result.chunks),
                "top_chunks": top_chunks,
                "error": result.error,
                **components,
            },
        )

    def _score_artifact(self, *, code, fmt, narration, task: Task, sandbox):
        """Call compute_generate_reward for an artifact and its sandbox result."""
        return compute_generate_reward(
            code=code,
            fmt=fmt,
            narration=narration,
            task=task,
            exec_success=sandbox.exec_success,
            accumulated_context=self._accumulated_context,
            static_check_passed=sandbox.check_passed,
            error_codes=sandbox.error_codes,
        )

    def _remember_artifact(self, *, code, fmt, narration, sandbox) -> str:
        """Store the latest submission and its errors; return the rendered errors."""
        self._last_code = code
        self._last_format = fmt
        self._last_narration = narration
        rendered_errors = _render_errors_with_hints(sandbox.render_errors(), sandbox.error_codes)
        self._last_errors = rendered_errors
        self._last_error_codes = sandbox.error_codes
        return rendered_errors

    def _finish_or_repair(self, done: bool) -> str:
        """Record whether the episode ended and return the resulting phase name."""
        phase = "done" if done else "repair"
        self._phase = phase
        self._done = done
        return phase

    @staticmethod
    def _format_components(components: dict) -> str:
        """Render reward components as the ``Reward: k=v, ...`` feedback line."""
        return f"Reward: {', '.join(f'{k}={v}' for k, v in components.items())}"

    def _handle_generate(self, action: ExplainerAction, task: Task) -> ExplainerObservation:
        """Process a generate action: run sandbox, maybe open repair phase."""
        if self._phase not in {"explore", "generate"}:
            return self._make_obs(
                task,
                phase=self._phase,
                feedback=f"Cannot generate during phase '{self._phase}'.",
                reward=0.0,
            )

        fmt = action.format or "marimo"
        code = action.code
        narration = action.narration

        # Penalise generating without any exploration
        if self._explore_steps == 0:
            skip_penalty = -0.1
            penalty_msg = "Warning: generating without any research. -0.1 penalty."
        else:
            skip_penalty = 0.0
            penalty_msg = ""

        sandbox = validate_code(fmt, code)

        # Generation reward
        reward, components = self._score_artifact(
            code=code, fmt=fmt, narration=narration, task=task, sandbox=sandbox
        )
        reward = clamp_action_reward(reward + skip_penalty)
        components["generate_total"] = round(reward, 4)

        rendered_errors = self._remember_artifact(
            code=code, fmt=fmt, narration=narration, sandbox=sandbox
        )

        # Feedback
        parts = []
        if penalty_msg:
            parts.append(penalty_msg)
        if not sandbox.parses:
            parts.append("SYNTAX ERROR: code does not parse.")
        elif not sandbox.exec_success:
            parts.append(f"EXECUTION FAILED: {rendered_errors}")
        else:
            parts.append(f"EXECUTION OK: {sandbox.message}")
        parts.append(self._format_components(components))

        # The episode ends on success, or when no repair budget exists at all.
        done = sandbox.exec_success or self._repair_steps >= MAX_REPAIR_STEPS
        phase = self._finish_or_repair(done)
        if not done:
            parts.append(
                f"Repair phase: {MAX_REPAIR_STEPS} attempts available. "
                "Submit a revised artifact using the error feedback."
            )

        return self._make_obs(
            task,
            phase=phase,
            feedback="\n".join(parts),
            reward=reward,
            done=done,
            last_errors="" if sandbox.exec_success else rendered_errors,
            metadata={
                "step": self._state.step_count,
                "phase": "generate",
                "explore_steps_used": self._explore_steps,
                "sandbox_message": sandbox.message,
                "error_codes": sandbox.error_codes,
                **components,
            },
        )

    def _handle_repair(self, action: ExplainerAction, task: Task) -> ExplainerObservation:
        """Process one repair attempt after a failed generate action."""
        if self._phase != "repair":
            return self._make_obs(
                task,
                phase=self._phase,
                feedback="Repair is only available after a failed generate step.",
                reward=0.0,
                done=self._done,
            )
        if self._repair_steps >= MAX_REPAIR_STEPS:
            self._phase = "done"
            self._done = True
            return self._make_obs(
                task,
                phase="done",
                feedback="No repair attempts left.",
                reward=0.0,
                done=True,
            )

        self._repair_steps += 1
        # Format and narration fall back to the previous submission's values
        # when the repair action leaves them empty.
        fmt = action.format or self._last_format or "marimo"
        code = action.code
        narration = action.narration or self._last_narration
        previous_code = self._last_code
        previous_errors = list(self._last_error_codes)

        sandbox = validate_code(fmt, code)
        base_reward, components = self._score_artifact(
            code=code, fmt=fmt, narration=narration, task=task, sandbox=sandbox
        )
        repair_reward, repair_components = adjust_repair_reward(
            base_reward,
            repair_success=sandbox.exec_success,
            previous_error_codes=previous_errors,
            new_error_codes=sandbox.error_codes,
            previous_code=previous_code,
            repaired_code=code,
        )
        components.update(repair_components)

        rendered_errors = self._remember_artifact(
            code=code, fmt=fmt, narration=narration, sandbox=sandbox
        )

        attempts_left = MAX_REPAIR_STEPS - self._repair_steps
        done = sandbox.exec_success or attempts_left <= 0
        phase = self._finish_or_repair(done)

        status = "REPAIR OK" if sandbox.exec_success else "REPAIR FAILED"
        feedback_parts = [
            f"{status}: {sandbox.message if sandbox.exec_success else rendered_errors}",
            self._format_components(components),
        ]
        if not done:
            feedback_parts.append(
                f"Repair phase continues: {attempts_left} repair attempts left. "
                "Submit another corrected artifact using the latest error feedback."
            )
        feedback = "\n".join(feedback_parts)

        return self._make_obs(
            task,
            phase=phase,
            feedback=feedback,
            reward=repair_reward,
            done=done,
            last_errors="" if sandbox.exec_success else rendered_errors,
            metadata={
                "step": self._state.step_count,
                "phase": "repair",
                "explore_steps_used": self._explore_steps,
                "repair_steps_used": self._repair_steps,
                "sandbox_message": sandbox.message,
                "error_codes": sandbox.error_codes,
                **components,
            },
        )

    def _make_obs(
        self,
        task: Task,
        *,
        phase: str,
        feedback: str,
        reward: float = 0.0,
        done: bool = False,
        search_results: str = "",
        top_chunks: list[dict] | None = None,
        last_errors: str | None = None,
        metadata: dict | None = None,
    ) -> ExplainerObservation:
        """Helper to build a consistent observation.

        Task fields, the accumulated research context and the remaining
        explore/repair budgets are filled in from the environment's current
        state. ``last_errors=None`` means "report the stored errors from the
        most recent submission"; pass ``""`` to report none.
        """
        return ExplainerObservation(
            topic=task.topic,
            content=task.content,
            tier=task.tier,
            keywords=task.keywords,
            data_available=task.data_available,
            difficulty=task.difficulty,
            phase=phase,
            feedback=feedback,
            search_results=search_results,
            top_chunks=top_chunks or [],
            explored_context="\n---\n".join(self._accumulated_context),
            explore_steps_left=MAX_EXPLORE_STEPS - self._explore_steps,
            repair_attempts_left=MAX_REPAIR_STEPS - self._repair_steps,
            last_errors=self._last_errors if last_errors is None else last_errors,
            available_tools=list(AVAILABLE_TOOLS),
            done=done,
            reward=reward,
            metadata=metadata or {},
        )

    # NOTE(review): OpenEnv's Environment interface appears to declare `state`
    # as a property; this override is a plain method. Left unchanged so that
    # existing `env.state()` callers keep working — confirm against the base
    # class and add @property if the server reads `env.state` as an attribute.
    def state(self) -> State:
        """Return the current episode State (episode id and step count)."""
        return self._state
| def _explore_action_text(tool: str, query: str, intent: str) -> str: | |
| return f"{tool} {query.strip()} {intent.strip()}".strip() | |
| def _top_chunks_payload(chunks) -> list[dict]: | |
| return [ | |
| { | |
| "rank": chunk.rank, | |
| "source": chunk.source, | |
| "title": chunk.title, | |
| "url": chunk.url, | |
| "score": round(chunk.score, 4), | |
| "snippet": chunk.text, | |
| } | |
| for chunk in chunks[:5] | |
| ] | |