|
|
|
|
|
import abc |
|
from typing import AsyncGenerator |
|
|
|
from aworld.core.agent.base import is_agent |
|
from aworld.core.common import ActionModel, TaskItem |
|
from aworld.core.event.base import Message, Constants, TopicType |
|
from aworld.core.tool.base import AsyncTool, Tool, ToolFactory |
|
from aworld.logs.util import logger |
|
from aworld.runners.handler.base import DefaultHandler |
|
|
|
|
|
class ToolHandler(DefaultHandler):
    """Common base for tool handlers; holds the runner's tools and tool configs."""

    # NOTE: Python 2 style metaclass declaration. Under Python 3 this is only a
    # plain class attribute and does not make the class abstract.
    __metaclass__ = abc.ABCMeta

    def __init__(self, runner: 'TaskEventRunner') -> None:
        """Keep references to the tools and tool configurations of ``runner``.

        Args:
            runner: Object exposing ``tools`` and ``tools_conf`` attributes.
        """
        self.tools, self.tools_conf = runner.tools, runner.tools_conf

    @classmethod
    def name(cls) -> str:
        """Return the fixed name string of this handler."""
        return "_tool_handler"
|
|
|
|
|
class DefaultToolHandler(ToolHandler):
    """Route tool-category messages to the tools named by their actions."""

    async def handle(self, message: Message) -> AsyncGenerator[Message, None]:
        """Validate the actions in ``message`` and forward them grouped by tool name.

        Messages whose category is not ``Constants.TOOL`` are ignored. A tool
        name that is not yet a key of ``self.tools`` gets an instance built via
        ``ToolFactory``, flagged ``event_driven``, reset, and stored in
        ``self.tools``; all tools created during this call are announced in a
        single ``TopicType.SUBSCRIBE_TOOL`` task message before any action is
        forwarded.

        Args:
            message: Incoming event. Its payload is iterated and every item
                must be an ``ActionModel``.

        Yields:
            A ``TopicType.ERROR`` task message with ``stop=True`` when the
            payload is empty or contains a non-``ActionModel`` item (nothing
            else is yielded in that case). Otherwise an optional subscribe
            message followed by one ``Constants.TOOL`` message per tool name,
            addressed to that tool and carrying its actions.
        """
        if message.category != Constants.TOOL:
            return

        headers = {"context": message.context}
        data = message.payload

        if not data:
            yield self._error_message(message, "no data to process.", headers)
            return

        if not all(isinstance(action, ActionModel) for action in data):
            yield self._error_message(message, "action not a ActionModel.", headers)
            return

        new_tools = {}
        tool_mapping = {}

        for act in data:
            if is_agent(act):
                logger.warning(f"somethings wrong, {act} is an agent.")
                continue

            tool_name = act.tool_name
            if not self.tools or tool_name not in self.tools:
                # Unknown tool: build it from its configuration entry, if any.
                # NOTE(review): a ``None`` value for ``self.tools`` or
                # ``self.tools_conf`` raises here; confirm the runner always
                # provides mappings.
                conf = self.tools_conf.get(tool_name)
                tool = ToolFactory(tool_name, conf=conf, asyn=conf.use_async if conf else False)
                tool.event_driven = True
                # Synchronous tools reset directly; async tools must be awaited.
                if isinstance(tool, Tool):
                    tool.reset()
                elif isinstance(tool, AsyncTool):
                    await tool.reset()
                self.tools[tool_name] = tool
                new_tools[tool_name] = tool

            tool_mapping.setdefault(tool_name, []).append(act)

        if new_tools:
            # Announce the newly created tools before sending them any action.
            yield Message(
                category=Constants.TASK,
                payload=TaskItem(data=new_tools),
                sender=self.name(),
                session_id=message.session_id,
                topic=TopicType.SUBSCRIBE_TOOL,
                headers=headers,
            )

        for tool_name, actions in tool_mapping.items():
            tool = self.tools[tool_name]
            if not isinstance(tool, (Tool, AsyncTool)):
                logger.warning(f"Unsupported tool type: {tool}")
                continue

            # Every list in ``tool_mapping`` holds at least one action, so the
            # first one can safely name the sending agent.
            yield Message(
                category=Constants.TOOL,
                payload=actions,
                sender=actions[0].agent_name,
                session_id=message.session_id,
                receiver=tool_name,
                headers=headers,
            )

    def _error_message(self, message: Message, msg: str, headers: dict) -> Message:
        """Build the task-stopping error message for an invalid payload of ``message``."""
        return Message(
            category=Constants.TASK,
            payload=TaskItem(msg=msg, data=message.payload, stop=True),
            # Errors are attributed to this handler's own name.
            sender=self.name(),
            session_id=message.session_id,
            topic=TopicType.ERROR,
            headers=headers,
        )
|
|