repl_env-0.2.2 / runner.py
burtenshaw's picture
burtenshaw HF Staff
Upload folder using huggingface_hub
7031bab verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Local recursive RLM runner for repl_env.
This keeps the iterative prompting/orchestration layer outside the environment,
following the same separation used by the official RLM implementation and DSPy:
- `REPLEnvironment` executes code and exposes tools
- `LocalRLMRunner` owns prompting, message history, and recursive child runs
"""
from __future__ import annotations
import re
import time
from dataclasses import dataclass
from typing import Callable
from .local import LocalREPLEnv
from .prompts import (
build_rlm_system_prompt,
build_user_prompt,
extract_code_blocks,
format_observations,
QueryMetadata,
RLM_SYSTEM_PROMPT,
)
from .recursive_backends import BackendLimits, LocalChildRLMBackend, RecursiveBackend
ChatFn = Callable[..., str]
@dataclass
class RLMRunResult:
final_answer: str | None
messages: list[dict[str, str]]
iterations: int
depth: int
child_traces: list[object]
class LocalRLMRunner:
"""Local recursive RLM orchestrator built on top of LocalREPLEnv."""
def __init__(
self,
llm_chat_fn: ChatFn,
*,
system_prompt: str = RLM_SYSTEM_PROMPT,
max_iterations: int = 30,
max_depth: int = 2,
depth: int = 0,
env_max_iterations_multiplier: int = 5,
max_batch_workers: int = 8,
backend_factory: Callable[..., RecursiveBackend] | None = None,
max_children_total: int | None = None,
max_children_per_batch: int | None = None,
result_truncation_limit: int | None = None,
per_child_timeout_s: float | None = None,
on_subcall_start: Callable[[int, str, str], None] | None = None,
on_subcall_complete: Callable[[int, str, float, str | None], None]
| None = None,
verbose: bool = False,
) -> None:
self.llm_chat_fn = llm_chat_fn
self.system_prompt = system_prompt
self.max_iterations = max_iterations
self.max_depth = max_depth
self.depth = depth
self.env_max_iterations_multiplier = env_max_iterations_multiplier
self.max_batch_workers = max_batch_workers
self.backend_factory = backend_factory or self._default_backend_factory
self.max_children_total = max_children_total
self.max_children_per_batch = max_children_per_batch
self.result_truncation_limit = result_truncation_limit
self.per_child_timeout_s = per_child_timeout_s
self.on_subcall_start = on_subcall_start
self.on_subcall_complete = on_subcall_complete
self.verbose = verbose
def _default_backend_factory(
self, llm_chat_fn: ChatFn, **kwargs
) -> RecursiveBackend:
limits = BackendLimits(
max_depth=self.max_depth,
max_batch_workers=self.max_batch_workers,
max_children_total=self.max_children_total,
max_children_per_batch=self.max_children_per_batch,
result_truncation_limit=self.result_truncation_limit,
per_child_timeout_s=self.per_child_timeout_s,
)
return LocalChildRLMBackend(
llm_chat_fn,
runner_factory=LocalRLMRunner,
system_prompt=kwargs["system_prompt"],
max_iterations=kwargs["max_iterations"],
env_max_iterations_multiplier=kwargs["env_max_iterations_multiplier"],
depth=kwargs["depth"],
limits=limits,
on_subcall_start=self.on_subcall_start,
on_subcall_complete=self.on_subcall_complete,
)
def run(
self,
context: str,
task_prompt: str,
*,
model: str | None = None,
timeout_s: float | None = None,
) -> RLMRunResult:
backend = self.backend_factory(
self.llm_chat_fn,
system_prompt=self.system_prompt,
max_iterations=self.max_iterations,
max_depth=self.max_depth,
depth=self.depth,
env_max_iterations_multiplier=self.env_max_iterations_multiplier,
)
with LocalREPLEnv(
llm_query_fn=backend.query,
llm_batch_fn=backend.query_batched,
subcall_fn=backend.recursive_query,
subcall_batch_fn=backend.recursive_query_batched,
) as env:
result = env.reset(
context=context,
task_prompt=task_prompt,
max_iterations=self.max_iterations * self.env_max_iterations_multiplier,
llm_model=model,
)
obs = result.observation
query_metadata = QueryMetadata(
context_lengths=[obs.context_length],
context_total_length=obs.context_length,
context_type="str",
)
messages = build_rlm_system_prompt(self.system_prompt, query_metadata)
messages.append(build_user_prompt(root_prompt=task_prompt, iteration=0))
run_start = time.perf_counter()
for iteration in range(1, self.max_iterations + 1):
# Cooperative timeout check (matches official RLM pattern)
if timeout_s is not None:
elapsed = time.perf_counter() - run_start
if elapsed >= timeout_s:
return RLMRunResult(
final_answer=f"Error: child timeout after {elapsed:.3f}s",
messages=messages,
iterations=iteration - 1,
depth=self.depth,
child_traces=list(getattr(backend, "child_traces", [])),
)
response = self._chat(messages, model)
code_blocks = extract_code_blocks(response)
code_block_observations = []
if self.verbose:
print(
f"[depth={self.depth}] iteration={iteration} code_blocks={len(code_blocks)}"
)
if not code_blocks:
messages.append({"role": "assistant", "content": response})
messages.append(
{
"role": "user",
"content": (
"Please continue by writing Python code in ```repl``` blocks, "
"or submit the final answer with FINAL(...) / FINAL_VAR(...)."
),
}
)
continue
for code in code_blocks:
result = env.execute(code)
code_block_observations.append(result.observation)
# Check for FINAL after all blocks executed (matches official RLM).
# The model expects all blocks to run — it often writes exploration
# code first and FINAL last in the same response.
if any(obs.done for obs in code_block_observations):
return RLMRunResult(
final_answer=env.state().final_answer,
messages=messages
+ [{"role": "assistant", "content": response}],
iterations=iteration,
depth=self.depth,
child_traces=list(getattr(backend, "child_traces", [])),
)
observation_text = format_observations(
code_block_observations, code_blocks=code_blocks
)
next_prompt = build_user_prompt(
root_prompt=task_prompt,
iteration=iteration,
)
messages.append({"role": "assistant", "content": response})
messages.append(
{
"role": "user",
"content": observation_text + "\n\n" + next_prompt["content"],
}
)
# Max iterations exhausted — give the model one final chance to answer
final_answer = env.state().final_answer
if final_answer is None:
final_answer = self._default_answer(messages, model)
return RLMRunResult(
final_answer=final_answer,
messages=messages,
iterations=self.max_iterations,
depth=self.depth,
child_traces=list(getattr(backend, "child_traces", [])),
)
def _default_answer(
self, messages: list[dict[str, str]], model: str | None = None
) -> str | None:
"""Make one final LLM call asking for an answer when iterations are exhausted."""
final_prompt = messages + [
{
"role": "user",
"content": (
"You have run out of REPL iterations. Based on all your work above, "
"provide your best final answer now. Use FINAL(your answer) to submit it. "
"If you stored the answer in a variable, use FINAL_VAR(variable_name) instead. "
"Do not write any more code — just provide the final answer."
),
}
]
try:
response = self._chat(final_prompt, model)
# Try to extract FINAL(...) from the response
match = re.search(r"FINAL\((.*?)\)", response, re.DOTALL)
if match:
return match.group(1).strip()
# If no FINAL pattern, return the raw response as best-effort
return response.strip() if response.strip() else None
except Exception:
return None
def _chat(self, messages: list[dict[str, str]], model: str | None = None) -> str:
try:
return self.llm_chat_fn(messages, model)
except TypeError:
return self.llm_chat_fn(messages)