Spaces:
Running
Running
| import inspect | |
| from collections import OrderedDict | |
| from typing import Callable, Dict, List, Union | |
| from lagent.actions.base_action import BaseAction | |
| from lagent.actions.builtin_actions import FinishAction, InvalidAction, NoAction | |
| from lagent.hooks import Hook, RemovableHandle | |
| from lagent.schema import ActionReturn, ActionValidCode, AgentMessage, FunctionCall | |
| from lagent.utils import create_object | |
class ActionExecutor:
    """Dispatch function calls to registered actions.

    Args:
        actions (Union[BaseAction, List[BaseAction], Dict, List[Dict]]): The
            action or actions to register, given as instances or as config
            dicts that are instantiated with ``create_object``. A list passed
            here is not modified.
        invalid_action (BaseAction, optional): Action used when the requested
            name matches neither a registered action nor a built-in one.
            Defaults to ``dict(type=InvalidAction)``.
        no_action (BaseAction, optional): Action used when the requested name
            equals ``no_action.name``. Defaults to ``dict(type=NoAction)``.
        finish_action (BaseAction, optional): Action used when the requested
            name equals ``finish_action.name``. Defaults to
            ``dict(type=FinishAction)``.
        finish_in_action (bool, optional): Whether the finish action is also
            registered in the regular action table. Defaults to False.
        hooks (List[Dict], optional): Hooks (instances or config dicts) that
            are instantiated with ``create_object`` and registered in order.
            Defaults to None.
    """

    def __init__(
        self,
        actions: Union[BaseAction, List[BaseAction], Dict, List[Dict]],
        invalid_action: BaseAction = dict(type=InvalidAction),
        no_action: BaseAction = dict(type=NoAction),
        finish_action: BaseAction = dict(type=FinishAction),
        finish_in_action: bool = False,
        hooks: List[Dict] = None,
    ):
        if not isinstance(actions, list):
            actions = [actions]
        finish_action = create_object(finish_action)
        if finish_in_action:
            # Build a new list instead of appending, so the caller's list
            # is not extended with the finish action.
            actions = [*actions, finish_action]
        # The comprehension builds a fresh list, so the caller's list is not
        # overwritten with the instantiated objects.
        actions = [create_object(action) for action in actions]
        self.actions = {action.name: action for action in actions}
        self.invalid_action = create_object(invalid_action)
        self.no_action = create_object(no_action)
        self.finish_action = finish_action
        self._hooks: Dict[int, Hook] = OrderedDict()
        if hooks:
            for hook in hooks:
                hook = create_object(hook)
                self.register_hook(hook)

    def description(self) -> List[Dict]:
        """Return a flat list of API descriptions for all registered actions.

        For toolkit actions (``action.is_toolkit``), each entry of
        ``description['api_list']`` is copied (shallowly) and its name is
        prefixed with ``"<action_name>."``. Other actions contribute a
        shallow copy of their ``description``.
        """
        actions = []
        for action_name, action in self.actions.items():
            if action.is_toolkit:
                for api in action.description['api_list']:
                    api_desc = api.copy()
                    api_desc['name'] = f"{action_name}.{api_desc['name']}"
                    actions.append(api_desc)
            else:
                action_desc = action.description.copy()
                actions.append(action_desc)
        return actions

    def __contains__(self, name: str):
        """Return whether an action is registered under ``name``."""
        return name in self.actions

    def keys(self):
        """Return the names of the registered actions as a list."""
        return list(self.actions.keys())

    def __setitem__(self, name: str, action: Union[BaseAction, Dict]):
        """Register ``action`` (an instance or config dict).

        NOTE(review): the ``name`` key is ignored. The action is stored under
        its own ``action.name``, which is the name ``forward`` looks up.
        """
        action = create_object(action)
        self.actions[action.name] = action

    def __delitem__(self, name: str):
        """Unregister the action stored under ``name`` (KeyError if absent)."""
        del self.actions[name]

    def forward(self, name, parameters, **kwargs) -> ActionReturn:
        """Execute the action selected by ``name`` with ``parameters``.

        ``name`` is either ``"<action>"`` (API ``'run'``) or
        ``"<action>.<api>"``. Only the first dot separates the action from
        the API, so a name with several dots no longer fails to unpack.

        If the action is not registered, ``name`` is matched against the
        no-action and finish-action names, falling back to the invalid
        action. These built-in results are returned unchanged. Results of
        registered actions get ``valid = ActionValidCode.OPEN``.
        """
        action_name, api_name = (
            name.split('.', 1) if '.' in name else (name, 'run'))
        if action_name not in self:
            if name == self.no_action.name:
                return self.no_action(parameters)
            if name == self.finish_action.name:
                return self.finish_action(parameters)
            return self.invalid_action(parameters)
        action_return: ActionReturn = self.actions[action_name](parameters,
                                                                api_name)
        action_return.valid = ActionValidCode.OPEN
        return action_return

    def __call__(self,
                 message: AgentMessage,
                 session_id=0,
                 **kwargs) -> AgentMessage:
        """Run hooks, execute the requested function call, and wrap the result.

        Each hook's ``before_action`` runs in registration order, and a truthy
        return value replaces ``message``. ``message.content`` must be a
        ``FunctionCall`` or a dict with ``'name'`` and ``'parameters'`` keys.
        The result of ``forward`` is wrapped in an ``AgentMessage`` (sender =
        class name) unless it already is one. Each hook's ``after_action``
        may then replace the response in the same way.
        """
        for hook in self._hooks.values():
            result = hook.before_action(self, message, session_id)
            if result:
                message = result
        assert isinstance(message.content, FunctionCall) or (
            isinstance(message.content, dict) and 'name' in message.content
            and 'parameters' in message.content)
        if isinstance(message.content, dict):
            name = message.content.get('name')
            parameters = message.content.get('parameters')
        else:
            name = message.content.name
            parameters = message.content.parameters
        response_message = self.forward(
            name=name, parameters=parameters, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.__class__.__name__,
                content=response_message,
            )
        for hook in self._hooks.values():
            result = hook.after_action(self, response_message, session_id)
            if result:
                response_message = result
        return response_message

    def register_hook(self, hook: Callable):
        """Register ``hook`` and return a ``RemovableHandle`` for it.

        The hook is stored in ``self._hooks`` under ``handle.id``. The handle
        is constructed over that same dict, presumably so that it can remove
        the hook later (see ``RemovableHandle``).
        """
        handle = RemovableHandle(self._hooks)
        self._hooks[handle.id] = hook
        return handle
class AsyncActionExecutor(ActionExecutor):
    """Asynchronous variant of ``ActionExecutor``.

    Actions, the built-in fallback actions, and hooks may be synchronous or
    asynchronous. Each one is called, and its result is awaited whenever it
    is awaitable. Checking the returned value rather than the callable also
    covers callables that return a coroutine without being declared with
    ``async def``.
    """

    @staticmethod
    async def _maybe_await(value):
        """Return ``await value`` if ``value`` is awaitable, else ``value``."""
        if inspect.isawaitable(value):
            return await value
        return value

    async def forward(self, name, parameters, **kwargs) -> ActionReturn:
        """Execute the action selected by ``name`` with ``parameters``.

        ``name`` is either ``"<action>"`` (API ``'run'``) or
        ``"<action>.<api>"``. Only the first dot separates the action from
        the API, so a name with several dots no longer fails to unpack.

        If the action is not registered, ``name`` is matched against the
        no-action and finish-action names, falling back to the invalid
        action. These built-in results are returned unchanged. Results of
        registered actions get ``valid = ActionValidCode.OPEN``.
        """
        action_name, api_name = (
            name.split('.', 1) if '.' in name else (name, 'run'))
        if action_name not in self:
            if name == self.no_action.name:
                fallback = self.no_action
            elif name == self.finish_action.name:
                fallback = self.finish_action
            else:
                fallback = self.invalid_action
            return await self._maybe_await(fallback(parameters))
        action = self.actions[action_name]
        action_return: ActionReturn = await self._maybe_await(
            action(parameters, api_name))
        action_return.valid = ActionValidCode.OPEN
        return action_return

    async def __call__(self,
                       message: AgentMessage,
                       session_id=0,
                       **kwargs) -> AgentMessage:
        """Run hooks, execute the requested function call, and wrap the result.

        Each hook's ``before_action`` runs in registration order (awaited if
        it returns an awaitable), and a truthy return value replaces
        ``message``. ``message.content`` must be a ``FunctionCall`` or a dict
        with ``'name'`` and ``'parameters'`` keys. The result of ``forward``
        is wrapped in an ``AgentMessage`` (sender = class name) unless it
        already is one. Each hook's ``after_action`` may then replace the
        response in the same way.
        """
        for hook in self._hooks.values():
            result = await self._maybe_await(
                hook.before_action(self, message, session_id))
            if result:
                message = result
        assert isinstance(message.content, FunctionCall) or (
            isinstance(message.content, dict) and 'name' in message.content
            and 'parameters' in message.content)
        if isinstance(message.content, dict):
            name = message.content.get('name')
            parameters = message.content.get('parameters')
        else:
            name = message.content.name
            parameters = message.content.parameters
        response_message = await self.forward(
            name=name, parameters=parameters, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.__class__.__name__,
                content=response_message,
            )
        for hook in self._hooks.values():
            result = await self._maybe_await(
                hook.after_action(self, response_message, session_id))
            if result:
                response_message = result
        return response_message