Duibonduil's picture
Upload 5 files
bc5e560 verified
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import abc
from typing import AsyncGenerator, Tuple
from aworld.agents.loop_llm_agent import LoopableAgent
from aworld.core.agent.base import is_agent, AgentFactory
from aworld.core.agent.swarm import GraphBuildType
from aworld.core.common import ActionModel, Observation, TaskItem
from aworld.core.event.base import Message, Constants, TopicType
from aworld.logs.util import logger
from aworld.runners.handler.base import DefaultHandler
from aworld.runners.handler.tool import DefaultToolHandler
from aworld.runners.utils import endless_detect
from aworld.output.base import StepOutput
class AgentHandler(DefaultHandler):
    """Base class for handlers that process agent-category runner messages.

    Keeps a reference to the runner's swarm and its endless-loop threshold,
    plus an ordered list of agent ids (``agent_calls``) that subclasses fill
    in while the task runs.
    """

    # NOTE(review): a ``__metaclass__`` class attribute is only honoured by
    # Python 2. Under Python 3 it is an ordinary attribute and does NOT make
    # this class abstract. It is kept as-is because switching to
    # ``metaclass=abc.ABCMeta`` could conflict with ``DefaultHandler``'s
    # metaclass — confirm before changing.
    __metaclass__ = abc.ABCMeta

    def __init__(self, runner: 'TaskEventRunner'):
        # Read the runner configuration once, in the same order as before.
        swarm, threshold = runner.swarm, runner.endless_threshold
        self.swarm = swarm
        self.endless_threshold = threshold
        # Ordered history of agent ids; populated by subclasses.
        self.agent_calls = list()

    @classmethod
    def name(cls):
        """Return the name under which this handler is addressed."""
        handler_name = "_agents_handler"
        return handler_name
class DefaultAgentHandler(AgentHandler):
    """Route agent-category messages between agents, tools and the task.

    An agent message carries either an ``Observation`` (optionally wrapped as
    the first element of a tuple) or a sequence of ``ActionModel``.
    Observations are forwarded to their receiving agent or converted into a
    stop check. Action lists are split into tool calls, which are sent to
    ``DefaultToolHandler``, and agent handoffs, which go through ``_agent``.
    """

    async def handle(self, message: Message) -> AsyncGenerator[Message, None]:
        """Process one message and yield the follow-up messages it produces.

        Messages whose category is not ``Constants.AGENT`` are only used to
        record their sender in ``agent_calls``; nothing is yielded for them.
        An empty payload yields a failed ``OUTPUT`` message followed by an
        ``ERROR`` task message that stops the task.
        """
        if message.category != Constants.AGENT:
            self._record_sender(message)
            return

        headers = {"context": message.context}
        session_id = message.session_id
        data = message.payload
        if not data:
            # error message, p2p
            yield Message(
                category=Constants.OUTPUT,
                payload=StepOutput.build_failed_output(name=f"{message.caller or self.name()}",
                                                       step_num=0,
                                                       data="no data to process."),
                sender=self.name(),
                session_id=session_id,
                headers=headers
            )
            yield Message(
                category=Constants.TASK,
                payload=TaskItem(msg="no data to process.", data=data, stop=True),
                sender=self.name(),
                session_id=session_id,
                topic=TopicType.ERROR,
                headers=headers
            )
            return

        # Unwrap an (Observation, ...) tuple and store the bare observation
        # back on the message, which may be forwarded unchanged later.
        if isinstance(data, tuple) and isinstance(data[0], Observation):
            data = data[0]
            message.payload = data

        if isinstance(data, Observation):
            async for event in self._handle_observation(data, message):
                yield event
            return

        # data is List[ActionModel]
        async for event in self._handle_actions(data, message):
            yield event

    def _record_sender(self, message: Message) -> None:
        """Append ``message.sender`` to ``agent_calls`` when it is a known agent.

        Consecutive duplicates are collapsed: the sender is appended only
        when it differs from the most recent entry.
        """
        # ``handle`` treats a missing swarm as a valid state (see
        # ``_handle_observation``), so do not dereference it blindly here.
        if not self.swarm:
            return
        sender = message.sender
        if sender in self.swarm.agents and sender in AgentFactory:
            if not self.agent_calls or self.agent_calls[-1] != sender:
                self.agent_calls.append(sender)

    async def _handle_observation(self, data: Observation, message: Message) -> AsyncGenerator[Message, None]:
        """Route an ``Observation`` payload addressed to ``message.receiver``.

        - With no swarm, the observation content finishes the task.
        - If the receiving agent is finished and ``data.info['done']`` is set,
          the step counter advances. The content then either finishes the task
          (receiver is the communicate agent) or is passed to the communicate
          agent.
        - Otherwise, a ``done`` observation triggers a stop check for the most
          recently called agent, and any other observation is forwarded
          unchanged.
        """
        headers = {"context": message.context}
        session_id = message.session_id
        if not self.swarm:
            msg = Message(
                category=Constants.TASK,
                payload=data.content,
                sender=data.observer,
                session_id=session_id,
                topic=TopicType.FINISHED,
                headers=headers
            )
            logger.info(f"agent handler send finished message: {msg}")
            yield msg
            return

        agent = self.swarm.agents.get(message.receiver)
        # agent + tool completion protocol.
        if agent and agent.finished and data.info.get('done'):
            self.swarm.cur_step += 1
            if agent.id() == self.swarm.communicate_agent.id():
                msg = Message(
                    category=Constants.TASK,
                    payload=data.content,
                    sender=agent.id(),
                    session_id=session_id,
                    topic=TopicType.FINISHED,
                    headers=headers
                )
                logger.info(f"agent handler send finished message: {msg}")
            else:
                msg = Message(
                    category=Constants.AGENT,
                    payload=Observation(content=data.content),
                    sender=agent.id(),
                    session_id=session_id,
                    receiver=self.swarm.communicate_agent.id(),
                    headers=headers
                )
                logger.info(f"agent handler send agent message: {msg}")
            yield msg
            return

        if data.info.get('done'):
            # NOTE(review): assumes at least one agent call was recorded;
            # an empty ``agent_calls`` raises IndexError here.
            agent_name = self.agent_calls[-1]
            async for event in self._stop_check(ActionModel(agent_name=agent_name, policy_info=data.content),
                                                message):
                yield event
            return

        logger.info(f"agent handler send observation message: {message}")
        yield message

    async def _handle_actions(self, data, message: Message) -> AsyncGenerator[Message, None]:
        """Validate a sequence of ``ActionModel`` and dispatch it.

        Any element that is not an ``ActionModel`` aborts with a failed
        ``OUTPUT`` message and an ``ERROR`` task message. Tool actions are
        batched into a single ``TOOL`` message. When there are no tool actions,
        a finished ``OUTPUT`` message is emitted instead. Every agent action
        is then processed by ``_agent``.
        """
        headers = {"context": message.context}
        session_id = message.session_id
        for action in data:
            if not isinstance(action, ActionModel):
                # error message, p2p
                yield Message(
                    category=Constants.OUTPUT,
                    payload=StepOutput.build_failed_output(name=f"{message.caller or self.name()}",
                                                           step_num=0,
                                                           data="action not a ActionModel."),
                    sender=self.name(),
                    session_id=session_id,
                    headers=headers
                )
                msg = Message(
                    category=Constants.TASK,
                    payload=TaskItem(msg="action not a ActionModel.", data=data, stop=True),
                    sender=self.name(),
                    session_id=session_id,
                    topic=TopicType.ERROR,
                    headers=headers
                )
                logger.info(f"agent handler send task message: {msg}")
                yield msg
                return

        tools = []
        agents = []
        for action in data:
            if is_agent(action):
                agents.append(action)
            else:
                tools.append(action)

        if tools:
            msg = Message(
                category=Constants.TOOL,
                payload=tools,
                sender=self.name(),
                session_id=session_id,
                receiver=DefaultToolHandler.name(),
                headers=headers
            )
            logger.info(f"agent handler send tool message: {msg}")
            yield msg
        else:
            # ``data`` is non-empty (checked in ``handle``) and has no tool
            # actions, so ``agents`` has at least one element here.
            yield Message(
                category=Constants.OUTPUT,
                payload=StepOutput.build_finished_output(name=f"{message.caller or self.name()}",
                                                         step_num=0),
                sender=self.name(),
                receiver=agents[0].tool_name,
                session_id=session_id,
                headers=headers
            )

        for agent in agents:
            async for event in self._agent(agent, message):
                logger.info(f"agent handler send message: {event}")
                yield event

    async def _agent(self, action: ActionModel, message: Message):
        """Turn an agent-targeted action into a handoff message.

        ``action.agent_name`` is the acting agent and ``action.tool_name`` is
        the agent to hand off to. An empty ``tool_name`` is routed through
        ``_stop_check`` instead of producing a handoff.
        """
        self.agent_calls.append(action.agent_name)
        agent = self.swarm.agents.get(action.agent_name)
        # be handoff
        agent_name = action.tool_name
        if not agent_name:
            async for event in self._stop_check(action, message):
                yield event
            return

        headers = {"context": message.context}
        session_id = message.session_id
        cur_agent = self.swarm.agents.get(agent_name)
        if not cur_agent or not agent:
            yield Message(
                category=Constants.TASK,
                payload=TaskItem(msg=f"Can not find {agent_name} or {action.agent_name} agent in swarm.",
                                 data=action,
                                 stop=True),
                sender=self.name(),
                session_id=session_id,
                topic=TopicType.ERROR,
                headers=headers
            )
            return

        # Clear the target's finished flag (presumably so it runs again).
        cur_agent._finished = False
        con = action.policy_info
        if action.params and 'content' in action.params:
            con = action.params['content']
        observation = Observation(content=con, observer=agent.id(), from_agent_name=agent.id())

        if agent.handoffs and agent_name not in agent.handoffs:
            # The acting agent is not allowed to hand off to this target.
            if message.caller:
                # Send the original message back to its caller.
                message.receiver = message.caller
                message.caller = ''
                yield message
            else:
                yield Message(category=Constants.TASK,
                              payload=TaskItem(msg=f"Can not handoffs {agent_name} agent ", data=observation),
                              sender=self.name(),
                              session_id=session_id,
                              topic=TopicType.RERUN,
                              headers=headers)
            return

        yield Message(
            category=Constants.AGENT,
            payload=observation,
            caller=message.caller,
            sender=action.agent_name,
            session_id=session_id,
            receiver=action.tool_name,
            headers=headers
        )

    async def _stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]:
        """Dispatch to the stop check that matches the swarm's build type.

        Non-workflow swarms use ``_social_stop_check``. Workflow swarms use
        ``_loop_sequence_stop_check`` when the graph has a cycle and
        ``_sequence_stop_check`` otherwise.
        """
        if GraphBuildType.WORKFLOW.value != self.swarm.build_type:
            checker = self._social_stop_check
        elif self.swarm.has_cycle:
            checker = self._loop_sequence_stop_check
        else:
            checker = self._sequence_stop_check
        async for event in checker(action, message):
            yield event

    @staticmethod
    def _advance_loop(agent) -> bool:
        """Count a run of a ``LoopableAgent`` and report whether it should loop.

        Returns False, without side effects, for agents that are not
        ``LoopableAgent``. For a ``LoopableAgent``, increments
        ``cur_run_times`` and returns True while the agent is not finished.
        """
        if not isinstance(agent, LoopableAgent):
            return False
        agent.cur_run_times += 1
        # Read ``finished`` after the increment; presumably it depends on
        # ``cur_run_times`` — TODO confirm in LoopableAgent.
        return not agent.finished

    async def _sequence_stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]:
        """Advance an acyclic workflow after ``action.agent_name`` has answered.

        A looping agent jumps to its ``goto`` target. Otherwise the answer is
        passed to the next agent in ``ordered_agents``, or it finishes the task
        when the agent is the last one. An agent missing from the order yields
        an ``ERROR`` task message.
        """
        headers = {"context": message.context}
        session_id = message.session_id
        agent = self.swarm.agents.get(action.agent_name)
        ordered_agents = self.swarm.ordered_agents
        idx = next((i for i, x in enumerate(ordered_agents) if x == agent), -1)
        if idx == -1:
            yield Message(
                category=Constants.TASK,
                payload=action,
                sender=self.name(),
                session_id=session_id,
                topic=TopicType.ERROR,
                headers=headers
            )
            return

        # The last agent
        if idx == len(ordered_agents) - 1:
            # agent loop
            receiver = agent.goto if self._advance_loop(agent) else None
            if receiver:
                yield Message(
                    category=Constants.AGENT,
                    payload=Observation(content=action.policy_info),
                    sender=agent.id(),
                    session_id=session_id,
                    receiver=receiver,
                    headers=headers
                )
            else:
                logger.info(f"execute loop {self.swarm.cur_step}.")
                yield Message(
                    category=Constants.TASK,
                    payload=action.policy_info,
                    sender=agent.id(),
                    session_id=session_id,
                    topic=TopicType.FINISHED,
                    headers=headers
                )
            return

        if self._advance_loop(agent):
            # loop agent type
            receiver = agent.goto
        else:
            # means the loop finished
            receiver = ordered_agents[idx + 1].id()
        yield Message(
            category=Constants.AGENT,
            payload=Observation(content=action.policy_info),
            sender=agent.id(),
            session_id=session_id,
            receiver=receiver,
            headers=headers
        )

    async def _loop_sequence_stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]:
        """Advance a cyclic workflow after ``action.agent_name`` has answered.

        Same as ``_sequence_stop_check``, with one difference at the last
        agent. While ``swarm.cur_step < swarm.max_steps``, the step counter is
        incremented and a ``START`` task message restarts the sequence.
        """
        headers = {"context": message.context}
        session_id = message.session_id
        agent = self.swarm.agents.get(action.agent_name)
        ordered_agents = self.swarm.ordered_agents
        idx = next((i for i, x in enumerate(ordered_agents) if x == agent), -1)
        if idx == -1:
            # unknown agent, means something wrong
            yield Message(
                category=Constants.TASK,
                payload=action,
                sender=self.name(),
                session_id=session_id,
                topic=TopicType.ERROR,
                headers=headers
            )
            return

        if idx == len(ordered_agents) - 1:
            # supported sequence loop
            if self.swarm.cur_step >= self.swarm.max_steps:
                # agent loop
                receiver = agent.goto if self._advance_loop(agent) else None
                if receiver:
                    yield Message(
                        category=Constants.AGENT,
                        payload=Observation(content=action.policy_info),
                        sender=agent.id(),
                        session_id=session_id,
                        receiver=receiver,
                        headers=headers
                    )
                else:
                    # means the task finished
                    yield Message(
                        category=Constants.TASK,
                        payload=action.policy_info,
                        sender=agent.id(),
                        session_id=session_id,
                        topic=TopicType.FINISHED,
                        headers=headers
                    )
            else:
                self.swarm.cur_step += 1
                logger.info(f"execute loop {self.swarm.cur_step}.")
                yield Message(
                    category=Constants.TASK,
                    payload='',
                    sender=agent.id(),
                    session_id=session_id,
                    topic=TopicType.START,
                    headers=headers
                )
            return

        if self._advance_loop(agent):
            receiver = agent.goto
        else:
            # means the loop finished
            receiver = ordered_agents[idx + 1].id()
        yield Message(
            category=Constants.AGENT,
            payload=Observation(content=action.policy_info),
            # Use the agent id like every other message: swarm agents and
            # ``agent_calls`` are keyed by id, not by display name.
            sender=agent.id(),
            session_id=session_id,
            receiver=receiver,
            headers=headers
        )

    async def _social_stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]:
        """Decide where an agent's answer goes in a non-workflow (social) swarm.

        The task finishes when ``endless_detect`` flags ``agent_calls``. With
        no caller (or the communicate agent as caller), the task finishes once
        ``max_steps`` is reached or the swarm is finished; otherwise the step
        counter advances and the answer goes back to the communicate agent.
        In every other case the answer is returned to the agent recorded just
        before this agent's last entry in ``agent_calls``, falling back to
        ``message.caller``.
        """
        headers = {"context": message.context}
        agent = self.swarm.agents.get(action.agent_name)
        caller = message.caller
        session_id = message.session_id
        if endless_detect(self.agent_calls,
                          endless_threshold=self.endless_threshold,
                          root_agent_name=self.swarm.communicate_agent.id()):
            yield Message(
                category=Constants.TASK,
                payload=action.policy_info,
                sender=agent.id(),
                session_id=session_id,
                topic=TopicType.FINISHED,
                headers=headers
            )
            return

        if not caller or caller == self.swarm.communicate_agent.id():
            if self.swarm.cur_step >= self.swarm.max_steps or self.swarm.finished:
                yield Message(
                    category=Constants.TASK,
                    payload=action.policy_info,
                    sender=agent.id(),
                    session_id=session_id,
                    topic=TopicType.FINISHED,
                    headers=headers
                )
            else:
                self.swarm.cur_step += 1
                logger.info(f"execute loop {self.swarm.cur_step}.")
                yield Message(
                    category=Constants.AGENT,
                    payload=Observation(content=action.policy_info),
                    sender=agent.id(),
                    session_id=session_id,
                    receiver=self.swarm.communicate_agent.id(),
                    headers=headers
                )
            return

        # Find this agent's last occurrence in the call history. The entry
        # before it is the agent that invoked it. When the agent is absent,
        # first, or the history is empty, keep ``message.caller``.
        agent_id = agent.id()
        calls = self.agent_calls
        last_idx = next((i for i in range(len(calls) - 1, -1, -1) if calls[i] == agent_id), -1)
        if last_idx > 0:
            caller = calls[last_idx - 1]
        yield Message(
            category=Constants.AGENT,
            payload=Observation(content=action.policy_info),
            sender=agent.id(),
            session_id=session_id,
            receiver=caller,
            headers=headers
        )