Spaces:
Sleeping
Sleeping
# coding: utf-8 | |
# Copyright (c) 2025 inclusionAI. | |
import abc | |
from typing import AsyncGenerator, Tuple | |
from aworld.agents.loop_llm_agent import LoopableAgent | |
from aworld.core.agent.base import is_agent, AgentFactory | |
from aworld.core.agent.swarm import GraphBuildType | |
from aworld.core.common import ActionModel, Observation, TaskItem | |
from aworld.core.event.base import Message, Constants, TopicType | |
from aworld.logs.util import logger | |
from aworld.runners.handler.base import DefaultHandler | |
from aworld.runners.handler.tool import DefaultToolHandler | |
from aworld.runners.utils import endless_detect | |
from aworld.output.base import StepOutput | |
class AgentHandler(DefaultHandler):
    """Base class for handlers that route agent-category messages.

    Binds to the runner's swarm and endless-loop threshold and keeps
    ``agent_calls``, the chain of agent identifiers recorded while routing.
    """

    # Python 2 style metaclass hook: Python 3 ignores this attribute, so the
    # class is not actually abstract through it. Kept unchanged to avoid a
    # possible metaclass conflict with DefaultHandler.
    __metaclass__ = abc.ABCMeta

    def __init__(self, runner: 'TaskEventRunner'):
        """Bind the handler to ``runner``.

        Args:
            runner: The owning task runner; its ``swarm`` and
                ``endless_threshold`` attributes are read once here.
        """
        self.swarm = runner.swarm
        self.endless_threshold = runner.endless_threshold
        # Chain of agent identifiers, appended to as messages are routed.
        self.agent_calls = []

    @classmethod
    def name(cls):
        """Return the identifier this handler uses for itself.

        A classmethod so it can be called on the class as well as on an
        instance (``self.name()`` keeps working).
        """
        return "_agents_handler"
class DefaultAgentHandler(AgentHandler):
    """Default router for ``Constants.AGENT`` messages.

    Messages of any other category are only used to record the agent call
    chain in ``self.agent_calls``. An agent message carries either an
    ``Observation`` (optionally as the first item of a tuple) or a list of
    ``ActionModel``. Tool actions are batched to the tool handler. Agent
    actions are handed off to their target agent; when there is no target,
    they go to the stop check matching the swarm's build type.
    """

    async def handle(self, message: Message) -> AsyncGenerator[Message, None]:
        """Route one message and yield the messages it produces.

        Args:
            message: Incoming event. Only ``Constants.AGENT`` messages are
                routed. For any other category, the sender is recorded in
                ``self.agent_calls`` when it is a known swarm agent, and
                nothing is yielded.

        Yields:
            Output, tool, agent or task messages, depending on the payload.
            An empty payload, or a list with a non-``ActionModel`` item,
            yields a failed output followed by a stopping
            ``TopicType.ERROR`` task message.
        """
        if message.category != Constants.AGENT:
            if message.sender in self.swarm.agents and message.sender in AgentFactory:
                # Record the call chain, collapsing consecutive repeats of the
                # same agent into one entry.
                if not self.agent_calls or self.agent_calls[-1] != message.sender:
                    self.agent_calls.append(message.sender)
            return

        headers = {"context": message.context}
        session_id = message.session_id
        data = message.payload
        if not data:
            # error message, p2p
            yield Message(
                category=Constants.OUTPUT,
                payload=StepOutput.build_failed_output(name=f"{message.caller or self.name()}",
                                                       step_num=0,
                                                       data="no data to process."),
                sender=self.name(),
                session_id=session_id,
                headers=headers
            )
            yield Message(
                category=Constants.TASK,
                payload=TaskItem(msg="no data to process.", data=data, stop=True),
                sender=self.name(),
                session_id=session_id,
                topic=TopicType.ERROR,
                headers=headers
            )
            return

        # An observation may arrive as the first item of a tuple; unwrap it.
        # ``data`` is non-empty here, so ``data[0]`` is safe.
        if isinstance(data, tuple) and isinstance(data[0], Observation):
            data = data[0]
            message.payload = data

        # data is Observation
        if isinstance(data, Observation):
            if not self.swarm:
                msg = Message(
                    category=Constants.TASK,
                    payload=data.content,
                    sender=data.observer,
                    session_id=session_id,
                    topic=TopicType.FINISHED,
                    headers=headers
                )
                logger.info(f"agent handler send finished message: {msg}")
                yield msg
                return

            agent = self.swarm.agents.get(message.receiver)
            # agent + tool completion protocol.
            if agent and agent.finished and data.info.get('done'):
                self.swarm.cur_step += 1
                if agent.id() == self.swarm.communicate_agent.id():
                    # The communicate agent is done: the whole task finishes.
                    msg = Message(
                        category=Constants.TASK,
                        payload=data.content,
                        sender=agent.id(),
                        session_id=session_id,
                        topic=TopicType.FINISHED,
                        headers=headers
                    )
                    logger.info(f"agent handler send finished message: {msg}")
                    yield msg
                else:
                    # Any other finished agent reports back to the communicate agent.
                    msg = Message(
                        category=Constants.AGENT,
                        payload=Observation(content=data.content),
                        sender=agent.id(),
                        session_id=session_id,
                        receiver=self.swarm.communicate_agent.id(),
                        headers=headers
                    )
                    logger.info(f"agent handler send agent message: {msg}")
                    yield msg
            else:
                if data.info.get('done'):
                    # NOTE(review): assumes at least one agent call has been
                    # recorded; an empty call chain raises IndexError here.
                    agent_name = self.agent_calls[-1]
                    async for event in self._stop_check(ActionModel(agent_name=agent_name, policy_info=data.content),
                                                        message):
                        yield event
                    return

                logger.info(f"agent handler send observation message: {message}")
                yield message
            return

        # data is List[ActionModel]
        for action in data:
            if not isinstance(action, ActionModel):
                # error message, p2p
                yield Message(
                    category=Constants.OUTPUT,
                    payload=StepOutput.build_failed_output(name=f"{message.caller or self.name()}",
                                                           step_num=0,
                                                           data="action not a ActionModel."),
                    sender=self.name(),
                    session_id=session_id,
                    headers=headers
                )
                msg = Message(
                    category=Constants.TASK,
                    payload=TaskItem(msg="action not a ActionModel.", data=data, stop=True),
                    sender=self.name(),
                    session_id=session_id,
                    topic=TopicType.ERROR,
                    headers=headers
                )
                logger.info(f"agent handler send task message: {msg}")
                yield msg
                return

        # Split the actions: agent handoffs are routed one by one via
        # ``_agent``; all other actions go to the tool handler as one batch.
        tools = []
        agents = []
        for action in data:
            if is_agent(action):
                agents.append(action)
            else:
                tools.append(action)

        if tools:
            msg = Message(
                category=Constants.TOOL,
                payload=tools,
                sender=self.name(),
                session_id=session_id,
                receiver=DefaultToolHandler.name(),
                headers=headers
            )
            logger.info(f"agent handler send tool message: {msg}")
            yield msg
        else:
            # No tool actions and ``data`` is non-empty, so ``agents`` is too.
            yield Message(
                category=Constants.OUTPUT,
                payload=StepOutput.build_finished_output(name=f"{message.caller or self.name()}",
                                                         step_num=0),
                sender=self.name(),
                receiver=agents[0].tool_name,
                session_id=session_id,
                headers=headers
            )

        for agent in agents:
            async for event in self._agent(agent, message):
                logger.info(f"agent handler send message: {event}")
                yield event

    async def _agent(self, action: ActionModel, message: Message):
        """Hand ``action`` off from ``action.agent_name`` to ``action.tool_name``.

        The source agent is recorded in ``self.agent_calls`` first.
        - No target: the action goes to ``_stop_check``.
        - Source or target agent missing from the swarm: a stopping
          ``TopicType.ERROR`` task message is yielded.
        - Target reset: otherwise the target's ``_finished`` flag is
          reset.
        - Target not in the source's declared ``handoffs``: the incoming
          message is bounced back to its caller. With no caller, a
          ``TopicType.RERUN`` task message is yielded instead.
        - Otherwise: an ``Observation`` built from ``action.params['content']``
          (falling back to ``action.policy_info``) is sent to the target.
        """
        self.agent_calls.append(action.agent_name)
        agent = self.swarm.agents.get(action.agent_name)
        # be handoff
        agent_name = action.tool_name
        if not agent_name:
            async for event in self._stop_check(action, message):
                yield event
            return

        headers = {"context": message.context}
        session_id = message.session_id
        cur_agent = self.swarm.agents.get(agent_name)
        if not cur_agent or not agent:
            yield Message(
                category=Constants.TASK,
                payload=TaskItem(msg=f"Can not find {agent_name} or {action.agent_name} agent in swarm.",
                                 data=action,
                                 stop=True),
                sender=self.name(),
                session_id=session_id,
                topic=TopicType.ERROR,
                headers=headers
            )
            return

        cur_agent._finished = False
        con = action.policy_info
        if action.params and 'content' in action.params:
            con = action.params['content']
        observation = Observation(content=con, observer=agent.id(), from_agent_name=agent.id())

        if agent.handoffs and agent_name not in agent.handoffs:
            if message.caller:
                # Bounce the original message back to whoever called us.
                message.receiver = message.caller
                message.caller = ''
                yield message
            else:
                yield Message(category=Constants.TASK,
                              payload=TaskItem(msg=f"Can not handoffs {agent_name} agent ", data=observation),
                              sender=self.name(),
                              session_id=session_id,
                              topic=TopicType.RERUN,
                              headers=headers)
            return

        yield Message(
            category=Constants.AGENT,
            payload=observation,
            caller=message.caller,
            sender=action.agent_name,
            session_id=session_id,
            receiver=action.tool_name,
            headers=headers
        )

    async def _stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]:
        """Route an action that has no handoff target.

        The strategy depends on the swarm:
        - Non-workflow (social) swarms use ``_social_stop_check``.
        - Workflow swarms use ``_loop_sequence_stop_check`` when
          ``swarm.has_cycle`` is set, and ``_sequence_stop_check`` otherwise.
        """
        if GraphBuildType.WORKFLOW.value != self.swarm.build_type:
            async for event in self._social_stop_check(action, message):
                yield event
        else:
            if self.swarm.has_cycle:
                async for event in self._loop_sequence_stop_check(action, message):
                    yield event
            else:
                async for event in self._sequence_stop_check(action, message):
                    yield event

    async def _sequence_stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]:
        """Advance a workflow swarm along ``swarm.ordered_agents``.

        - Unknown agent: yields a ``TopicType.ERROR`` task message.
        - ``LoopableAgent``: its ``cur_run_times`` is incremented, and it
          jumps to ``agent.goto`` while not finished.
        - Otherwise, the result goes to the next agent.
        - Last agent with no loop target: the task finishes.
        """
        headers = {"context": message.context}
        session_id = message.session_id
        agent = self.swarm.agents.get(action.agent_name)
        ordered_agents = self.swarm.ordered_agents
        idx = next((i for i, x in enumerate(ordered_agents) if x == agent), -1)
        if idx == -1:
            yield Message(
                category=Constants.TASK,
                payload=action,
                sender=self.name(),
                session_id=session_id,
                topic=TopicType.ERROR,
                headers=headers
            )
            return

        # The last agent
        if idx == len(self.swarm.ordered_agents) - 1:
            receiver = None
            # agent loop
            if isinstance(agent, LoopableAgent):
                agent.cur_run_times += 1
                if not agent.finished:
                    receiver = agent.goto

            if receiver:
                yield Message(
                    category=Constants.AGENT,
                    payload=Observation(content=action.policy_info),
                    sender=agent.id(),
                    session_id=session_id,
                    receiver=receiver,
                    headers=headers
                )
            else:
                logger.info(f"execute loop {self.swarm.cur_step}.")
                yield Message(
                    category=Constants.TASK,
                    payload=action.policy_info,
                    sender=agent.id(),
                    session_id=session_id,
                    topic=TopicType.FINISHED,
                    headers=headers
                )
            return

        # loop agent type
        if isinstance(agent, LoopableAgent):
            agent.cur_run_times += 1
            if agent.finished:
                receiver = self.swarm.ordered_agents[idx + 1].id()
            else:
                receiver = agent.goto
        else:
            # means the loop finished
            receiver = self.swarm.ordered_agents[idx + 1].id()

        yield Message(
            category=Constants.AGENT,
            payload=Observation(content=action.policy_info),
            sender=agent.id(),
            session_id=session_id,
            receiver=receiver,
            headers=headers
        )

    async def _loop_sequence_stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]:
        """Advance a workflow swarm whose graph has a cycle.

        Agents before the last one behave as in ``_sequence_stop_check``.
        After the last agent, while ``swarm.cur_step`` is below
        ``swarm.max_steps``, ``cur_step`` is incremented and a new round is
        started with a ``TopicType.START`` task message. Once the limit is
        reached, the last agent either loops to its ``goto`` (an unfinished
        ``LoopableAgent``) or the task finishes.
        """
        headers = {"context": message.context}
        session_id = message.session_id
        agent = self.swarm.agents.get(action.agent_name)
        idx = next((i for i, x in enumerate(self.swarm.ordered_agents) if x == agent), -1)
        if idx == -1:
            # unknown agent, means something wrong
            yield Message(
                category=Constants.TASK,
                payload=action,
                sender=self.name(),
                session_id=session_id,
                topic=TopicType.ERROR,
                headers=headers
            )
            return

        if idx == len(self.swarm.ordered_agents) - 1:
            # supported sequence loop
            if self.swarm.cur_step >= self.swarm.max_steps:
                receiver = None
                # agent loop
                if isinstance(agent, LoopableAgent):
                    agent.cur_run_times += 1
                    if not agent.finished:
                        receiver = agent.goto

                if receiver:
                    yield Message(
                        category=Constants.AGENT,
                        payload=Observation(content=action.policy_info),
                        sender=agent.id(),
                        session_id=session_id,
                        receiver=receiver,
                        headers=headers
                    )
                else:
                    # means the task finished
                    yield Message(
                        category=Constants.TASK,
                        payload=action.policy_info,
                        sender=agent.id(),
                        session_id=session_id,
                        topic=TopicType.FINISHED,
                        headers=headers
                    )
            else:
                self.swarm.cur_step += 1
                logger.info(f"execute loop {self.swarm.cur_step}.")
                yield Message(
                    category=Constants.TASK,
                    payload='',
                    sender=agent.id(),
                    session_id=session_id,
                    topic=TopicType.START,
                    headers=headers
                )
            return

        if isinstance(agent, LoopableAgent):
            agent.cur_run_times += 1
            if agent.finished:
                receiver = self.swarm.ordered_agents[idx + 1].id()
            else:
                receiver = agent.goto
        else:
            # means the loop finished
            receiver = self.swarm.ordered_agents[idx + 1].id()

        # Identify the sender by id, like every other routing message here
        # (previously ``agent.name()``, which disagreed with the id-based
        # receivers and call-chain tracking).
        yield Message(
            category=Constants.AGENT,
            payload=Observation(content=action.policy_info),
            sender=agent.id(),
            session_id=session_id,
            receiver=receiver,
            headers=headers
        )

    async def _social_stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]:
        """Route an action without a handoff target in a social (non-workflow) swarm.

        - The task finishes when ``endless_detect`` flags the call chain.
        - When the message has no caller, or its caller is the communicate
          agent: the task finishes if ``swarm.max_steps`` is reached or the
          swarm is finished. Otherwise ``swarm.cur_step`` is incremented and
          the result is sent to the communicate agent.
        - Otherwise the result is returned to the agent that called this one,
          looked up in ``self.agent_calls``.
        """
        headers = {"context": message.context}
        agent = self.swarm.agents.get(action.agent_name)
        caller = message.caller
        session_id = message.session_id
        if endless_detect(self.agent_calls,
                          endless_threshold=self.endless_threshold,
                          root_agent_name=self.swarm.communicate_agent.id()):
            yield Message(
                category=Constants.TASK,
                payload=action.policy_info,
                sender=agent.id(),
                session_id=session_id,
                topic=TopicType.FINISHED,
                headers=headers
            )
            return

        if not caller or caller == self.swarm.communicate_agent.id():
            if self.swarm.cur_step >= self.swarm.max_steps or self.swarm.finished:
                yield Message(
                    category=Constants.TASK,
                    payload=action.policy_info,
                    sender=agent.id(),
                    session_id=session_id,
                    topic=TopicType.FINISHED,
                    headers=headers
                )
            else:
                self.swarm.cur_step += 1
                logger.info(f"execute loop {self.swarm.cur_step}.")
                yield Message(
                    category=Constants.AGENT,
                    payload=Observation(content=action.policy_info),
                    sender=agent.id(),
                    session_id=session_id,
                    receiver=self.swarm.communicate_agent.id(),
                    headers=headers
                )
        else:
            # The caller is the entry just before this agent's most recent
            # appearance in the call chain. If the agent is absent, is the
            # first entry, or the chain is empty, the message's own caller is
            # kept. (The old reverse-enumerate arithmetic produced index -1 on
            # an empty chain and raised IndexError.)
            agent_id = agent.id()
            last_pos = next((i for i in range(len(self.agent_calls) - 1, -1, -1)
                             if self.agent_calls[i] == agent_id), 0)
            if last_pos:
                caller = self.agent_calls[last_pos - 1]
            yield Message(
                category=Constants.AGENT,
                payload=Observation(content=action.policy_info),
                sender=agent.id(),
                session_id=session_id,
                receiver=caller,
                headers=headers
            )