Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import logging | |
import time | |
from abc import ABC, abstractmethod | |
from asyncio import CancelledError | |
from functools import wraps | |
from typing import ( | |
TYPE_CHECKING, | |
Any, | |
Callable, | |
Dict, | |
List, | |
NoReturn, | |
Optional, | |
Tuple, | |
Type, | |
Union, | |
) | |
from langchain_core.agents import AgentAction, AgentFinish | |
from langchain_core.load.dump import dumpd | |
from langchain_core.outputs import RunInfo | |
from langchain_core.utils.input import get_color_mapping | |
from langchain.callbacks.manager import ( | |
AsyncCallbackManager, | |
AsyncCallbackManagerForChainRun, | |
CallbackManager, | |
CallbackManagerForChainRun, | |
Callbacks, | |
) | |
from langchain.schema import RUN_KEY | |
from langchain.tools import BaseTool | |
from langchain.utilities.asyncio import asyncio_timeout | |
if TYPE_CHECKING:
    # Imported for static type checking only; it is never executed at runtime.
    # The annotations that use it stay unevaluated because of
    # ``from __future__ import annotations``.
    from langchain.agents.agent import AgentExecutor

logger = logging.getLogger(__name__)
class BaseAgentExecutorIterator(ABC):
    """Base class for AgentExecutorIterator.

    Subclasses must implement :meth:`build_callback_manager`. The
    ``rebuild_callback_manager_on_set`` decorator calls that method after
    each decorated setter has run.
    """

    @abstractmethod
    def build_callback_manager(self) -> None:
        """(Re)build the callback manager from the iterator's current state."""
def rebuild_callback_manager_on_set(
    setter_method: Callable[..., None]
) -> Callable[..., None]:
    """Decorator to force setters to rebuild callback mgr.

    Wraps a setter so that, once it has stored the new value, the owning
    object's ``build_callback_manager()`` is called.

    Args:
        setter_method: The undecorated setter, called as
            ``setter_method(self, *args, **kwargs)``.

    Returns:
        A wrapper that keeps ``setter_method``'s name and docstring (via
        ``functools.wraps``). The wrapper always returns None.
    """

    @wraps(setter_method)
    def wrapper(self: BaseAgentExecutorIterator, *args: Any, **kwargs: Any) -> None:
        setter_method(self, *args, **kwargs)
        # Rebuild only after the new value is stored, so the manager sees it.
        self.build_callback_manager()

    return wrapper
class AgentExecutorIterator(BaseAgentExecutorIterator):
    """Iterator for AgentExecutor.

    Steps an ``AgentExecutor`` one agent step at a time, either synchronously
    (``iter``/``next``) or asynchronously (``aiter``/``anext``; construct with
    ``async_=True``).

    Each step returns one of:

    * ``{"intermediate_step": ...}``, or
    * the executor's ``_return``/``_areturn`` output, when the agent finishes
      or is stopped early.

    The call after that raises ``StopIteration`` / ``StopAsyncIteration``
    carrying the prepared final outputs.
    """

    def __init__(
        self,
        agent_executor: AgentExecutor,
        inputs: Any,
        callbacks: Callbacks = None,
        *,
        tags: Optional[list[str]] = None,
        include_run_info: bool = False,
        async_: bool = False,
    ):
        """
        Initialize the AgentExecutorIterator with the given AgentExecutor,
        inputs, and optional callbacks.

        Args:
            agent_executor: The executor whose steps this iterator drives.
            inputs: Raw inputs; normalised via ``agent_executor.prep_inputs``.
            callbacks: Run-level callbacks, passed to the callback manager's
                ``configure`` together with the executor's own callbacks.
            tags: Run-level tags, passed to ``configure`` together with the
                executor's own tags.
            include_run_info: If True, add a ``RunInfo`` under ``RUN_KEY`` to
                the final outputs.
            async_: Build an ``AsyncCallbackManager`` (for ``async for``)
                instead of a ``CallbackManager``.
        """
        # Assign the private attribute directly: the public setter would
        # rebuild the callback manager before callbacks/tags exist.
        self._agent_executor = agent_executor
        self.inputs = inputs
        self.async_ = async_
        # build callback manager on tags setter
        self._callbacks = callbacks
        self.tags = tags
        self.include_run_info = include_run_info
        self.run_manager = None
        # Set by __aiter__ when max_execution_time is configured; None means
        # "no active timeout context".
        self.timeout_manager = None
        self.reset()

    _callback_manager: Union[AsyncCallbackManager, CallbackManager]
    _inputs: dict[str, str]
    _final_outputs: Optional[dict[str, Any]]
    run_manager: Optional[
        Union[AsyncCallbackManagerForChainRun, CallbackManagerForChainRun]
    ]
    timeout_manager: Any  # TODO: Fix a type here; the shim makes it tricky.

    @property
    def inputs(self) -> dict[str, str]:
        """The inputs, as prepared by the agent executor's ``prep_inputs``."""
        return self._inputs

    @inputs.setter
    def inputs(self, inputs: Any) -> None:
        self._inputs = self.agent_executor.prep_inputs(inputs)

    @property
    def callbacks(self) -> Callbacks:
        """Run-level callbacks used when building the callback manager."""
        return self._callbacks

    @callbacks.setter
    @rebuild_callback_manager_on_set
    def callbacks(self, callbacks: Callbacks) -> None:
        """When callbacks are changed after __init__, rebuild callback mgr"""
        self._callbacks = callbacks

    @property
    def tags(self) -> Optional[List[str]]:
        """Run-level tags used when building the callback manager."""
        return self._tags

    @tags.setter
    @rebuild_callback_manager_on_set
    def tags(self, tags: Optional[List[str]]) -> None:
        """When tags are changed after __init__, rebuild callback mgr"""
        self._tags = tags

    @property
    def agent_executor(self) -> AgentExecutor:
        """The agent executor being iterated."""
        return self._agent_executor

    @agent_executor.setter
    @rebuild_callback_manager_on_set
    def agent_executor(self, agent_executor: AgentExecutor) -> None:
        """Swap the executor; re-prep inputs and rebuild callback mgr."""
        self._agent_executor = agent_executor
        # force re-prep inputs in case agent_executor's prep_inputs fn changed
        self.inputs = self.inputs

    @property
    def callback_manager(self) -> Union[AsyncCallbackManager, CallbackManager]:
        """The manager created by the last ``build_callback_manager`` call."""
        return self._callback_manager

    def build_callback_manager(self) -> None:
        """
        Create and configure the callback manager based on the current
        callbacks and tags.
        """
        CallbackMgr: Union[Type[AsyncCallbackManager], Type[CallbackManager]] = (
            AsyncCallbackManager if self.async_ else CallbackManager
        )
        self._callback_manager = CallbackMgr.configure(
            self.callbacks,
            self.agent_executor.callbacks,
            self.agent_executor.verbose,
            self.tags,
            self.agent_executor.tags,
        )

    @property
    def name_to_tool_map(self) -> dict[str, BaseTool]:
        """Map each of the executor's tools by its ``name``."""
        return {tool.name: tool for tool in self.agent_executor.tools}

    @property
    def color_mapping(self) -> dict[str, str]:
        """Per-tool colors, excluding green and red."""
        return get_color_mapping(
            [tool.name for tool in self.agent_executor.tools],
            excluded_colors=["green", "red"],
        )

    def reset(self) -> None:
        """
        Reset the iterator to its initial state, clearing intermediate steps,
        iterations, and time elapsed.
        """
        logger.debug("(Re)setting AgentExecutorIterator to fresh state")
        self.intermediate_steps: list[tuple[AgentAction, str]] = []
        self.iterations = 0
        # maybe better to start these on the first __anext__ call?
        self.time_elapsed = 0.0
        self.start_time = time.time()
        self._final_outputs = None

    def update_iterations(self) -> None:
        """
        Increment the number of iterations and update the time elapsed.
        """
        self.iterations += 1
        self.time_elapsed = time.time() - self.start_time
        logger.debug(
            "Agent Iterations: %s (%.2fs elapsed)", self.iterations, self.time_elapsed
        )

    def raise_stopiteration(self, output: Any) -> NoReturn:
        """
        Raise a StopIteration exception with the given output.
        """
        logger.debug("Chain end: stop iteration")
        raise StopIteration(output)

    async def _aexit_timeout_manager(self) -> None:
        """Exit the active async timeout context, if any, and forget it.

        The attribute is cleared before exiting, so calling this again is a
        no-op and the context is never exited twice.
        """
        timeout_manager, self.timeout_manager = self.timeout_manager, None
        if timeout_manager is not None:
            await timeout_manager.__aexit__(None, None, None)

    async def raise_stopasynciteration(self, output: Any) -> NoReturn:
        """
        Raise a StopAsyncIteration exception with the given output.
        Close the timeout context manager.
        """
        logger.debug("Chain end: stop async iteration")
        await self._aexit_timeout_manager()
        raise StopAsyncIteration(output)

    @property
    def final_outputs(self) -> Optional[dict[str, Any]]:
        """Prepared final outputs, or None while the run is still going."""
        return self._final_outputs

    @final_outputs.setter
    def final_outputs(self, outputs: Optional[Dict[str, Any]]) -> None:
        # have access to intermediate steps by design in iterator,
        # so return only outputs may as well always be true.
        self._final_outputs = None
        if outputs:
            prepared_outputs: dict[str, Any] = self.agent_executor.prep_outputs(
                self.inputs, outputs, return_only_outputs=True
            )
            if self.include_run_info and self.run_manager is not None:
                logger.debug("Assign run key")
                prepared_outputs[RUN_KEY] = RunInfo(run_id=self.run_manager.run_id)
            self._final_outputs = prepared_outputs

    def __iter__(self: "AgentExecutorIterator") -> "AgentExecutorIterator":
        """Reset state and start a sync run (requires ``async_=False``)."""
        logger.debug("Initialising AgentExecutorIterator")
        self.reset()
        assert isinstance(self.callback_manager, CallbackManager)
        self.run_manager = self.callback_manager.on_chain_start(
            dumpd(self.agent_executor),
            self.inputs,
        )
        return self

    def __aiter__(self) -> "AgentExecutorIterator":
        """
        N.B. __aiter__ must be a normal method, so need to initialise async run manager
        on first __anext__ call where we can await it
        """
        logger.debug("Initialising AgentExecutorIterator (async)")
        self.reset()
        if self.agent_executor.max_execution_time:
            self.timeout_manager = asyncio_timeout(
                self.agent_executor.max_execution_time
            )
        else:
            self.timeout_manager = None
        return self

    def _on_first_step(self) -> None:
        """
        Perform any necessary setup for the first step of the synchronous iterator.
        """
        pass

    async def _on_first_async_step(self) -> None:
        """
        Perform any necessary setup for the first step of the asynchronous iterator.
        """
        # on first step, need to await callback manager and start async timeout ctxmgr
        if self.iterations == 0:
            assert isinstance(self.callback_manager, AsyncCallbackManager)
            self.run_manager = await self.callback_manager.on_chain_start(
                dumpd(self.agent_executor),
                self.inputs,
            )
            if self.timeout_manager:
                await self.timeout_manager.__aenter__()

    def __next__(self) -> dict[str, Any]:
        """
        AgentExecutor          AgentExecutorIterator
        __call__               (__iter__ ->) __next__
        _call           <=>    _call_next
        _take_next_step        _take_next_step
        """
        # First step only. A run that was stopped before any step completed
        # keeps iterations == 0 but already has final outputs.
        if self.iterations == 0 and self.final_outputs is None:
            self._on_first_step()
        # N.B. timeout taken care of by "_should_continue" in sync case
        try:
            return self._call_next()
        except StopIteration:
            raise
        except BaseException as e:
            if self.run_manager:
                self.run_manager.on_chain_error(e)
            raise

    async def __anext__(self) -> dict[str, Any]:
        """
        AgentExecutor          AgentExecutorIterator
        acall                  (__aiter__ ->) __anext__
        _acall          <=>    _acall_next
        _atake_next_step       _atake_next_step
        """
        # Set up only once. Without the final_outputs check, a timeout during
        # the first step would trigger a second on_chain_start on the next call.
        if self.iterations == 0 and self.final_outputs is None:
            await self._on_first_async_step()
        try:
            return await self._acall_next()
        except StopAsyncIteration:
            raise
        except (TimeoutError, CancelledError) as e:
            if self.timeout_manager is None:
                # No timeout context of ours is active, so this did not come
                # from max_execution_time: report and propagate it like any
                # other failure (previously this raised AttributeError).
                if self.run_manager:
                    assert isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
                    await self.run_manager.on_chain_error(e)
                raise
            # Presumably max_execution_time expired (an external cancellation
            # while a timeout is active is treated the same): close the timeout
            # context and return the early-stopped response.
            await self._aexit_timeout_manager()
            return await self._astop()
        except BaseException as e:
            # Disarm any pending timeout so it cannot cancel this task later.
            await self._aexit_timeout_manager()
            if self.run_manager:
                assert isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
                await self.run_manager.on_chain_error(e)
            raise

    def _execute_next_step(
        self, run_manager: Optional[CallbackManagerForChainRun]
    ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
        """
        Execute the next step in the chain using the
        AgentExecutor's _take_next_step method.
        """
        return self.agent_executor._take_next_step(
            self.name_to_tool_map,
            self.color_mapping,
            self.inputs,
            self.intermediate_steps,
            run_manager=run_manager,
        )

    async def _execute_next_async_step(
        self, run_manager: Optional[AsyncCallbackManagerForChainRun]
    ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
        """
        Execute the next step in the chain using the
        AgentExecutor's _atake_next_step method.
        """
        return await self.agent_executor._atake_next_step(
            self.name_to_tool_map,
            self.color_mapping,
            self.inputs,
            self.intermediate_steps,
            run_manager=run_manager,
        )

    def _process_next_step_output(
        self,
        next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
        run_manager: Optional[CallbackManagerForChainRun],
    ) -> Dict[str, Union[str, List[Tuple[AgentAction, str]]]]:
        """
        Process the output of the next step,
        handling AgentFinish and tool return cases.
        """
        logger.debug("Processing output of Agent loop step")
        if isinstance(next_step_output, AgentFinish):
            logger.debug(
                "Hit AgentFinish: _return -> on_chain_end -> run final output logic"
            )
            output = self.agent_executor._return(
                next_step_output, self.intermediate_steps, run_manager=run_manager
            )
            if run_manager:
                run_manager.on_chain_end(output)
            self.final_outputs = output
            return output

        self.intermediate_steps.extend(next_step_output)
        logger.debug("Updated intermediate_steps with step output")

        # Check for tool return
        if len(next_step_output) == 1:
            next_step_action = next_step_output[0]
            tool_return = self.agent_executor._get_tool_return(next_step_action)
            if tool_return is not None:
                output = self.agent_executor._return(
                    tool_return, self.intermediate_steps, run_manager=run_manager
                )
                if run_manager:
                    run_manager.on_chain_end(output)
                self.final_outputs = output
                return output

        output = {"intermediate_step": next_step_output}
        return output

    async def _aprocess_next_step_output(
        self,
        next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
        run_manager: Optional[AsyncCallbackManagerForChainRun],
    ) -> Dict[str, Union[str, List[Tuple[AgentAction, str]]]]:
        """
        Process the output of the next async step,
        handling AgentFinish and tool return cases.
        """
        logger.debug("Processing output of async Agent loop step")
        if isinstance(next_step_output, AgentFinish):
            logger.debug(
                "Hit AgentFinish: _areturn -> on_chain_end -> run final output logic"
            )
            output = await self.agent_executor._areturn(
                next_step_output, self.intermediate_steps, run_manager=run_manager
            )
            if run_manager:
                await run_manager.on_chain_end(output)
            self.final_outputs = output
            return output

        self.intermediate_steps.extend(next_step_output)
        logger.debug("Updated intermediate_steps with step output")

        # Check for tool return
        if len(next_step_output) == 1:
            next_step_action = next_step_output[0]
            tool_return = self.agent_executor._get_tool_return(next_step_action)
            if tool_return is not None:
                output = await self.agent_executor._areturn(
                    tool_return, self.intermediate_steps, run_manager=run_manager
                )
                if run_manager:
                    await run_manager.on_chain_end(output)
                self.final_outputs = output
                return output

        output = {"intermediate_step": next_step_output}
        return output

    def _stop(self) -> dict[str, Any]:
        """
        Build the agent's early-stopped response, end the run, record it as
        the final output and return it. The following ``__next__`` call then
        raises StopIteration.
        """
        logger.warning("Stopping agent prematurely due to triggering stop condition")
        # this manually constructs agent finish with output key
        output = self.agent_executor.agent.return_stopped_response(
            self.agent_executor.early_stopping_method,
            self.intermediate_steps,
            **self.inputs,
        )
        assert (
            isinstance(self.run_manager, CallbackManagerForChainRun)
            or self.run_manager is None
        )
        returned_output = self.agent_executor._return(
            output, self.intermediate_steps, run_manager=self.run_manager
        )
        # Close the run opened by on_chain_start, as the AgentFinish path does.
        if self.run_manager:
            self.run_manager.on_chain_end(returned_output)
        self.final_outputs = returned_output
        return returned_output

    async def _astop(self) -> dict[str, Any]:
        """
        Build the agent's early-stopped response, end the run, record it as
        the final output and return it. The following ``__anext__`` call then
        raises StopAsyncIteration.
        """
        logger.warning("Stopping agent prematurely due to triggering stop condition")
        output = self.agent_executor.agent.return_stopped_response(
            self.agent_executor.early_stopping_method,
            self.intermediate_steps,
            **self.inputs,
        )
        assert (
            isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
            or self.run_manager is None
        )
        returned_output = await self.agent_executor._areturn(
            output, self.intermediate_steps, run_manager=self.run_manager
        )
        # Close the run opened by on_chain_start, as the AgentFinish path does.
        if self.run_manager:
            await self.run_manager.on_chain_end(returned_output)
        self.final_outputs = returned_output
        return returned_output

    def _call_next(self) -> dict[str, Any]:
        """
        Perform a single iteration of the synchronous AgentExecutorIterator.
        """
        # final output already reached: stopiteration (final output)
        if self.final_outputs is not None:
            self.raise_stopiteration(self.final_outputs)
        # timeout/max iterations: stopiteration (stopped response)
        if not self.agent_executor._should_continue(self.iterations, self.time_elapsed):
            return self._stop()
        assert (
            isinstance(self.run_manager, CallbackManagerForChainRun)
            or self.run_manager is None
        )
        next_step_output = self._execute_next_step(self.run_manager)
        output = self._process_next_step_output(next_step_output, self.run_manager)
        self.update_iterations()
        return output

    async def _acall_next(self) -> dict[str, Any]:
        """
        Perform a single iteration of the asynchronous AgentExecutorIterator.
        """
        # final output already reached: stopiteration (final output)
        if self.final_outputs is not None:
            await self.raise_stopasynciteration(self.final_outputs)
        # timeout/max iterations: stopiteration (stopped response)
        if not self.agent_executor._should_continue(self.iterations, self.time_elapsed):
            return await self._astop()
        assert (
            isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
            or self.run_manager is None
        )
        next_step_output = await self._execute_next_async_step(self.run_manager)
        output = await self._aprocess_next_step_output(
            next_step_output, self.run_manager
        )
        self.update_iterations()
        return output