# webui/langchain/agents/agent_iterator.py
# Uploaded by zhangyi617 ("Upload folder using huggingface_hub"), commit 129cd69
from __future__ import annotations
import logging
import time
from abc import ABC, abstractmethod
from asyncio import CancelledError
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
NoReturn,
Optional,
Tuple,
Type,
Union,
)
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.load.dump import dumpd
from langchain_core.outputs import RunInfo
from langchain_core.utils.input import get_color_mapping
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.schema import RUN_KEY
from langchain.tools import BaseTool
from langchain.utilities.asyncio import asyncio_timeout
if TYPE_CHECKING:
    # AgentExecutor is only referenced in annotations in this module; presumably
    # it is imported here (not at runtime) to avoid a circular import with
    # langchain.agents.agent — TODO confirm.
    from langchain.agents.agent import AgentExecutor
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class BaseAgentExecutorIterator(ABC):
    """Abstract interface shared by agent-executor iterators.

    Concrete iterators must implement ``build_callback_manager``; setters
    decorated with ``rebuild_callback_manager_on_set`` invoke it after they
    update the iterator's state.
    """

    @abstractmethod
    def build_callback_manager(self) -> None:
        """(Re)create the callback manager from the iterator's current state."""
def rebuild_callback_manager_on_set(
    setter_method: Callable[..., None]
) -> Callable[..., None]:
    """Wrap a property setter so the callback manager is rebuilt after it runs.

    The wrapped setter is called first with all of its arguments; only if it
    returns normally is ``build_callback_manager()`` invoked on the same
    instance. The wrapper always returns ``None``.
    """

    @wraps(setter_method)
    def _set_then_rebuild(
        instance: BaseAgentExecutorIterator, *args: Any, **kwargs: Any
    ) -> None:
        setter_method(instance, *args, **kwargs)
        instance.build_callback_manager()

    return _set_then_rebuild
class AgentExecutorIterator(BaseAgentExecutorIterator):
    """Iterator for AgentExecutor.

    Steps through an ``AgentExecutor`` run one agent step at a time, either
    synchronously (``__iter__`` / ``__next__``) or asynchronously
    (``__aiter__`` / ``__anext__``). Each step returns either
    ``{"intermediate_step": [...]}`` or, once the agent finishes or is stopped,
    the executor's returned output. The call after that raises
    ``StopIteration`` / ``StopAsyncIteration`` carrying ``final_outputs``.
    """

    def __init__(
        self,
        agent_executor: AgentExecutor,
        inputs: Any,
        callbacks: Callbacks = None,
        *,
        tags: Optional[list[str]] = None,
        include_run_info: bool = False,
        async_: bool = False,
    ):
        """
        Initialize the AgentExecutorIterator with the given AgentExecutor,
        inputs, and optional callbacks.

        Args:
            agent_executor: The executor whose steps are iterated.
            inputs: Raw inputs; normalised through ``agent_executor.prep_inputs``.
            callbacks: Run callbacks, passed to the callback manager alongside
                the executor's own callbacks.
            tags: Run tags, passed to the callback manager alongside the
                executor's own tags.
            include_run_info: If True, ``final_outputs`` gets a ``RUN_KEY``
                entry holding the run id (when a run manager exists).
            async_: If True, an ``AsyncCallbackManager`` is built for async
                iteration; otherwise a ``CallbackManager``.
        """
        self._agent_executor = agent_executor
        self.inputs = inputs
        self.async_ = async_
        # build callback manager on tags setter
        self._callbacks = callbacks
        self.tags = tags
        self.include_run_info = include_run_info
        self.run_manager = None
        # Also set in __aiter__; initialised here so the async helpers never
        # hit an AttributeError if __anext__ is driven without __aiter__.
        self.timeout_manager = None
        self.reset()

    _callback_manager: Union[AsyncCallbackManager, CallbackManager]
    _inputs: dict[str, str]
    _final_outputs: Optional[dict[str, str]]
    run_manager: Optional[
        Union[AsyncCallbackManagerForChainRun, CallbackManagerForChainRun]
    ]
    timeout_manager: Any  # TODO: Fix a type here; the shim makes it tricky.

    @property
    def inputs(self) -> dict[str, str]:
        """Inputs as normalised by ``agent_executor.prep_inputs``."""
        return self._inputs

    @inputs.setter
    def inputs(self, inputs: Any) -> None:
        self._inputs = self.agent_executor.prep_inputs(inputs)

    @property
    def callbacks(self) -> Callbacks:
        return self._callbacks

    @callbacks.setter
    @rebuild_callback_manager_on_set
    def callbacks(self, callbacks: Callbacks) -> None:
        """When callbacks are changed after __init__, rebuild callback mgr"""
        self._callbacks = callbacks

    @property
    def tags(self) -> Optional[List[str]]:
        return self._tags

    @tags.setter
    @rebuild_callback_manager_on_set
    def tags(self, tags: Optional[List[str]]) -> None:
        """When tags are changed after __init__, rebuild callback mgr"""
        self._tags = tags

    @property
    def agent_executor(self) -> AgentExecutor:
        return self._agent_executor

    @agent_executor.setter
    @rebuild_callback_manager_on_set
    def agent_executor(self, agent_executor: AgentExecutor) -> None:
        self._agent_executor = agent_executor
        # force re-prep inputs in case agent_executor's prep_inputs fn changed
        self.inputs = self.inputs

    @property
    def callback_manager(self) -> Union[AsyncCallbackManager, CallbackManager]:
        return self._callback_manager

    def build_callback_manager(self) -> None:
        """
        Create and configure the callback manager based on the current
        callbacks and tags.

        Uses ``AsyncCallbackManager`` when ``async_`` is set, otherwise
        ``CallbackManager``; both are configured with this iterator's
        callbacks/tags and the executor's callbacks, verbosity and tags.
        """
        CallbackMgr: Union[Type[AsyncCallbackManager], Type[CallbackManager]] = (
            AsyncCallbackManager if self.async_ else CallbackManager
        )
        self._callback_manager = CallbackMgr.configure(
            self.callbacks,
            self.agent_executor.callbacks,
            self.agent_executor.verbose,
            self.tags,
            self.agent_executor.tags,
        )

    @property
    def name_to_tool_map(self) -> dict[str, BaseTool]:
        """Map each of the executor's tools from its name to the tool."""
        return {tool.name: tool for tool in self.agent_executor.tools}

    @property
    def color_mapping(self) -> dict[str, str]:
        """Assign a colour to each tool name, never using green or red."""
        return get_color_mapping(
            [tool.name for tool in self.agent_executor.tools],
            excluded_colors=["green", "red"],
        )

    def reset(self) -> None:
        """
        Reset the iterator to its initial state, clearing intermediate steps,
        iterations, time elapsed, and any final outputs.
        """
        logger.debug("(Re)setting AgentExecutorIterator to fresh state")
        self.intermediate_steps: list[tuple[AgentAction, str]] = []
        self.iterations = 0
        # maybe better to start these on the first __anext__ call?
        self.time_elapsed = 0.0
        self.start_time = time.time()
        self._final_outputs = None

    def update_iterations(self) -> None:
        """
        Increment the number of iterations and update the time elapsed.
        """
        self.iterations += 1
        self.time_elapsed = time.time() - self.start_time
        logger.debug(
            f"Agent Iterations: {self.iterations} ({self.time_elapsed:.2f}s elapsed)"
        )

    def raise_stopiteration(self, output: Any) -> NoReturn:
        """
        Raise a StopIteration exception with the given output.
        """
        logger.debug("Chain end: stop iteration")
        raise StopIteration(output)

    async def raise_stopasynciteration(self, output: Any) -> NoReturn:
        """
        Raise a StopAsyncIteration exception with the given output.
        Close the timeout context manager.
        """
        logger.debug("Chain end: stop async iteration")
        await self._aclose_timeout_manager()
        raise StopAsyncIteration(output)

    async def _aclose_timeout_manager(self) -> None:
        """Exit the async timeout context if one is set, then forget it.

        Clearing the attribute means a later call (e.g. ``__anext__`` after
        exhaustion) does not exit the same context a second time.
        """
        if self.timeout_manager is not None:
            await self.timeout_manager.__aexit__(None, None, None)
            self.timeout_manager = None

    @property
    def final_outputs(self) -> Optional[dict[str, Any]]:
        return self._final_outputs

    @final_outputs.setter
    def final_outputs(self, outputs: Optional[Dict[str, Any]]) -> None:
        # have access to intermediate steps by design in iterator,
        # so return only outputs may as well always be true.
        self._final_outputs = None
        if outputs:
            prepared_outputs: dict[str, Any] = self.agent_executor.prep_outputs(
                self.inputs, outputs, return_only_outputs=True
            )
            if self.include_run_info and self.run_manager is not None:
                logger.debug("Assign run key")
                prepared_outputs[RUN_KEY] = RunInfo(run_id=self.run_manager.run_id)
            self._final_outputs = prepared_outputs

    def __iter__(self: "AgentExecutorIterator") -> "AgentExecutorIterator":
        """Reset state and start a synchronous chain run."""
        logger.debug("Initialising AgentExecutorIterator")
        self.reset()
        assert isinstance(self.callback_manager, CallbackManager)
        self.run_manager = self.callback_manager.on_chain_start(
            dumpd(self.agent_executor),
            self.inputs,
        )
        return self

    def __aiter__(self) -> "AgentExecutorIterator":
        """
        N.B. __aiter__ must be a normal method, so need to initialise async run manager
        on first __anext__ call where we can await it
        """
        logger.debug("Initialising AgentExecutorIterator (async)")
        self.reset()
        if self.agent_executor.max_execution_time:
            self.timeout_manager = asyncio_timeout(
                self.agent_executor.max_execution_time
            )
        else:
            self.timeout_manager = None
        return self

    def _on_first_step(self) -> None:
        """
        Perform any necessary setup for the first step of the synchronous iterator.
        """
        pass

    async def _on_first_async_step(self) -> None:
        """
        Perform any necessary setup for the first step of the asynchronous iterator.

        Starts the async chain run and enters the timeout context (if any).
        Skipped once a final output exists: a run that stopped on its very first
        step still has ``iterations == 0``, and must not be started (or have its
        timeout entered) a second time.
        """
        # on first step, need to await callback manager and start async timeout ctxmgr
        if self.iterations == 0 and self.final_outputs is None:
            assert isinstance(self.callback_manager, AsyncCallbackManager)
            self.run_manager = await self.callback_manager.on_chain_start(
                dumpd(self.agent_executor),
                self.inputs,
            )
            if self.timeout_manager:
                await self.timeout_manager.__aenter__()

    def __next__(self) -> dict[str, Any]:
        """
        AgentExecutor AgentExecutorIterator
        __call__ (__iter__ ->) __next__
        _call <=> _call_next
        _take_next_step _take_next_step
        """
        # first step
        if self.iterations == 0:
            self._on_first_step()
        # N.B. timeout taken care of by "_should_continue" in sync case
        try:
            return self._call_next()
        except StopIteration:
            raise
        except BaseException as e:
            if self.run_manager:
                self.run_manager.on_chain_error(e)
            raise

    async def __anext__(self) -> dict[str, Any]:
        """
        AgentExecutor AgentExecutorIterator
        acall (__aiter__ ->) __anext__
        _acall <=> _acall_next
        _atake_next_step _atake_next_step
        """
        if self.iterations == 0:
            await self._on_first_async_step()
        try:
            return await self._acall_next()
        except StopAsyncIteration:
            raise
        except (TimeoutError, CancelledError) as e:
            if self.timeout_manager is None:
                # No execution-time limit is active, so this is not our timeout
                # firing; report and propagate it like any other error instead of
                # crashing on ``None.__aexit__``.
                await self._areport_chain_error(e)
                raise
            # Treated as the max_execution_time limit firing: close the timeout
            # context and return the agent's stopped response.
            await self._aclose_timeout_manager()
            return await self._astop()
        except BaseException as e:
            # Close the timeout context so a still-pending deadline cannot fire
            # after this error has already propagated.
            await self._aclose_timeout_manager()
            await self._areport_chain_error(e)
            raise

    async def _areport_chain_error(self, error: BaseException) -> None:
        """Forward *error* to the async run manager, if a run has been started."""
        if self.run_manager:
            assert isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
            await self.run_manager.on_chain_error(error)

    def _execute_next_step(
        self, run_manager: Optional[CallbackManagerForChainRun]
    ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
        """
        Execute the next step in the chain using the
        AgentExecutor's _take_next_step method.
        """
        return self.agent_executor._take_next_step(
            self.name_to_tool_map,
            self.color_mapping,
            self.inputs,
            self.intermediate_steps,
            run_manager=run_manager,
        )

    async def _execute_next_async_step(
        self, run_manager: Optional[AsyncCallbackManagerForChainRun]
    ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
        """
        Execute the next step in the chain using the
        AgentExecutor's _atake_next_step method.
        """
        return await self.agent_executor._atake_next_step(
            self.name_to_tool_map,
            self.color_mapping,
            self.inputs,
            self.intermediate_steps,
            run_manager=run_manager,
        )

    def _process_next_step_output(
        self,
        next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
        run_manager: Optional[CallbackManagerForChainRun],
    ) -> Dict[str, Union[str, List[Tuple[AgentAction, str]]]]:
        """
        Process the output of the next step,
        handling AgentFinish and tool return cases.

        On AgentFinish or a single step whose tool result is returned directly,
        ends the chain run, records ``final_outputs`` and returns the executor's
        output; otherwise appends the steps to ``intermediate_steps`` and returns
        ``{"intermediate_step": next_step_output}``.
        """
        logger.debug("Processing output of Agent loop step")
        if isinstance(next_step_output, AgentFinish):
            logger.debug(
                "Hit AgentFinish: _return -> on_chain_end -> run final output logic"
            )
            output = self.agent_executor._return(
                next_step_output, self.intermediate_steps, run_manager=run_manager
            )
            if run_manager:
                run_manager.on_chain_end(output)
            self.final_outputs = output
            return output

        self.intermediate_steps.extend(next_step_output)
        logger.debug("Updated intermediate_steps with step output")

        # Check for tool return
        if len(next_step_output) == 1:
            next_step_action = next_step_output[0]
            tool_return = self.agent_executor._get_tool_return(next_step_action)
            if tool_return is not None:
                output = self.agent_executor._return(
                    tool_return, self.intermediate_steps, run_manager=run_manager
                )
                if run_manager:
                    run_manager.on_chain_end(output)
                self.final_outputs = output
                return output

        output = {"intermediate_step": next_step_output}
        return output

    async def _aprocess_next_step_output(
        self,
        next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
        run_manager: Optional[AsyncCallbackManagerForChainRun],
    ) -> Dict[str, Union[str, List[Tuple[AgentAction, str]]]]:
        """
        Process the output of the next async step,
        handling AgentFinish and tool return cases.

        Async counterpart of ``_process_next_step_output``.
        """
        logger.debug("Processing output of async Agent loop step")
        if isinstance(next_step_output, AgentFinish):
            logger.debug(
                "Hit AgentFinish: _areturn -> on_chain_end -> run final output logic"
            )
            output = await self.agent_executor._areturn(
                next_step_output, self.intermediate_steps, run_manager=run_manager
            )
            if run_manager:
                await run_manager.on_chain_end(output)
            self.final_outputs = output
            return output

        self.intermediate_steps.extend(next_step_output)
        logger.debug("Updated intermediate_steps with step output")

        # Check for tool return
        if len(next_step_output) == 1:
            next_step_action = next_step_output[0]
            tool_return = self.agent_executor._get_tool_return(next_step_action)
            if tool_return is not None:
                output = await self.agent_executor._areturn(
                    tool_return, self.intermediate_steps, run_manager=run_manager
                )
                if run_manager:
                    await run_manager.on_chain_end(output)
                self.final_outputs = output
                return output

        output = {"intermediate_step": next_step_output}
        return output

    def _stop(self) -> dict[str, Any]:
        """
        Build and return the agent's stopped response when a stop condition hits.

        Ends the chain run, records the response as ``final_outputs`` and returns
        it; the next ``__next__`` call then raises ``StopIteration``.
        """
        logger.warning("Stopping agent prematurely due to triggering stop condition")
        # this manually constructs agent finish with output key
        output = self.agent_executor.agent.return_stopped_response(
            self.agent_executor.early_stopping_method,
            self.intermediate_steps,
            **self.inputs,
        )
        assert (
            isinstance(self.run_manager, CallbackManagerForChainRun)
            or self.run_manager is None
        )
        returned_output = self.agent_executor._return(
            output, self.intermediate_steps, run_manager=self.run_manager
        )
        # A stopped run is still a finished run: close it like the AgentFinish /
        # tool-return paths do, so on_chain_start has a matching on_chain_end.
        if self.run_manager:
            self.run_manager.on_chain_end(returned_output)
        self.final_outputs = returned_output
        return returned_output

    async def _astop(self) -> dict[str, Any]:
        """
        Build and return the agent's stopped response (async) when a stop
        condition or the execution-time limit hits.

        Ends the chain run, records the response as ``final_outputs`` and returns
        it; the next ``__anext__`` call then raises ``StopAsyncIteration``.
        """
        logger.warning("Stopping agent prematurely due to triggering stop condition")
        output = self.agent_executor.agent.return_stopped_response(
            self.agent_executor.early_stopping_method,
            self.intermediate_steps,
            **self.inputs,
        )
        assert (
            isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
            or self.run_manager is None
        )
        returned_output = await self.agent_executor._areturn(
            output, self.intermediate_steps, run_manager=self.run_manager
        )
        # Mirror the AgentFinish / tool-return paths: end the chain run.
        if self.run_manager:
            await self.run_manager.on_chain_end(returned_output)
        self.final_outputs = returned_output
        return returned_output

    def _call_next(self) -> dict[str, Any]:
        """
        Perform a single iteration of the synchronous AgentExecutorIterator.
        """
        # final output already reached: stopiteration (final output)
        if self.final_outputs is not None:
            self.raise_stopiteration(self.final_outputs)
        # timeout/max iterations: stopiteration (stopped response)
        if not self.agent_executor._should_continue(self.iterations, self.time_elapsed):
            return self._stop()
        assert (
            isinstance(self.run_manager, CallbackManagerForChainRun)
            or self.run_manager is None
        )
        next_step_output = self._execute_next_step(self.run_manager)
        output = self._process_next_step_output(next_step_output, self.run_manager)
        self.update_iterations()
        return output

    async def _acall_next(self) -> dict[str, Any]:
        """
        Perform a single iteration of the asynchronous AgentExecutorIterator.
        """
        # final output already reached: stopiteration (final output)
        if self.final_outputs is not None:
            await self.raise_stopasynciteration(self.final_outputs)
        # timeout/max iterations: stopiteration (stopped response)
        if not self.agent_executor._should_continue(self.iterations, self.time_elapsed):
            return await self._astop()
        assert (
            isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
            or self.run_manager is None
        )
        next_step_output = await self._execute_next_async_step(self.run_manager)
        output = await self._aprocess_next_step_output(
            next_step_output, self.run_manager
        )
        self.update_iterations()
        return output