File size: 4,644 Bytes
d79f338
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import abc
import time
import uuid
from typing import Callable, Any

from pydantic import BaseModel

import aworld.tools
from aworld.config import ConfigDict
from aworld.config.conf import ToolConfig
from aworld.core.agent.swarm import Swarm
from aworld.core.common import Observation
from aworld.core.context.base import Context
from aworld.core.context.session import Session
from aworld.core.tool.base import Tool, AsyncTool
from aworld.core.task import Task, TaskResponse, Runner
from aworld.logs.util import logger
from aworld import trace


class TaskRunner(Runner):
    """Task based runner api class.

    Validates and normalizes a `Task` at construction time, prepares the
    execution state (tools, session, trace id, swarm) in `pre_run`, and leaves
    the actual execution to subclasses through `do_run`.
    """
    # NOTE(review): in Python 3 a `__metaclass__` class attribute is ignored, so
    # `@abc.abstractmethod` is only enforced if `Runner` already uses
    # `abc.ABCMeta` as its metaclass -- confirm against `Runner`.
    __metaclass__ = abc.ABCMeta

    def __init__(self,
                 task: Task,
                 *,
                 agent_oriented: bool = True,
                 daemon_target: Callable[..., Any] = None):
        """Task runner initialize.

        The given task is normalized in place: `tools` / `tool_names` default
        to empty lists, a single `agent` is wrapped into a `Swarm`, and `conf`
        becomes a plain dict (empty when missing, `model_dump()` when it is a
        pydantic model).

        Args:
            task: Task entity to be executed.
            agent_oriented: Is it an agent oriented task, default is True.
            daemon_target: Optional callable stored on the runner as
                `self.daemon_target`; it is not invoked by this class.

        Raises:
            ValueError: If the task is agent oriented and provides neither or
                both of `agent` and `swarm`, or if `conf["check_input"]` is
                truthy while the task has no input.
        """
        if task.tools is None:
            task.tools = []
        if task.tool_names is None:
            task.tool_names = []

        if agent_oriented:
            if not task.agent and not task.swarm:
                raise ValueError("agent and swarm all is None.")
            if task.agent and task.swarm:
                raise ValueError("agent and swarm choose one only.")
            if task.agent:
                # uniform agent
                task.swarm = Swarm(task.agent)

        if task.conf is None:
            task.conf = dict()
        if isinstance(task.conf, BaseModel):
            task.conf = task.conf.model_dump()
        check_input = task.conf.get("check_input", False)
        if check_input and not task.input:
            raise ValueError("task no input")

        # Fall back to the shared context instance when the task carries none.
        self.context = task.context if task.context else Context.instance()
        self.task = task
        self.context.set_task(task)
        self.agent_oriented = agent_oriented
        self.daemon_target = daemon_target
        self._use_demon = False if not task.conf else task.conf.get('use_demon', False)
        self._exception = None
        self.start_time = time.time()
        self.step_agent_counter = {}

    async def pre_run(self):
        """Prepare runner state from the task before `do_run` is executed.

        Copies the task fields onto the runner, registers the names of the
        instantiated tools in `task.tool_names`, forces the async `mcp` tool
        configuration, fills the context (task id, trace id, session, swarm),
        resets every tool and finally resets the swarm with the initial
        observation content.
        """
        task = self.task
        self.swarm = task.swarm
        self.input = task.input
        self.outputs = task.outputs
        self.name = task.name
        self.conf = task.conf if task.conf else ConfigDict()
        self.tools = {tool.name(): tool for tool in task.tools} if task.tools else {}
        # Register instantiated tools by name, skipping names that are already
        # listed so that repeated `pre_run` calls (or names supplied by the
        # caller) do not produce duplicates.
        for tool_name in self.tools:
            if tool_name not in task.tool_names:
                task.tool_names.append(tool_name)
        # lazy load
        self.tool_names = task.tool_names
        self.tools_conf = task.tools_conf
        if self.tools_conf is None:
            self.tools_conf = {}
        # mcp performs special process, use async only in the runner
        self.tools_conf['mcp'] = ToolConfig(use_async=True, name='mcp')
        self.endless_threshold = task.endless_threshold

        # build context
        if task.session_id:
            session = Session(session_id=task.session_id)
        else:
            session = Session(session_id=uuid.uuid1().hex)
        # Reuse the trace id of the active span when there is one, otherwise
        # generate a fresh id.
        current_span = trace.get_current_span()
        trace_id = uuid.uuid1().hex if current_span is None else current_span.get_trace_id()
        self.context.task_id = self.name
        self.context.trace_id = trace_id
        self.context.session = session
        self.context.swarm = self.swarm

        # init tool state by reset(), and ignore them observation
        observation = None
        if self.tools:
            for tool in self.tools.values():
                # use the observation of the last one
                if isinstance(tool, Tool):
                    tool.context = self.context
                    observation, _ = tool.reset()
                elif isinstance(tool, AsyncTool):
                    # NOTE(review): unlike sync tools, async tools do not get
                    # `context` assigned here -- confirm this is intentional.
                    observation, _ = await tool.reset()
                else:
                    logger.warning(f"Unsupported tool type: {tool}, will ignored.")

        # The task input is used as content when no tool produced any.
        if observation:
            if not observation.content:
                observation.content = self.input
        else:
            observation = Observation(content=self.input)

        self.observation = observation
        if self.swarm:
            self.swarm.event_driven = task.event_driven
            self.swarm.reset(observation.content, context=self.context, tools=self.tool_names)

    async def post_run(self):
        """Reset the runner context after the task has been run."""
        self.context.reset()

    @abc.abstractmethod
    async def do_run(self, context: Context = None) -> TaskResponse:
        """Task do run.

        Subclasses implement the actual task execution here.

        Args:
            context: Optional context for the run, default is None.

        Returns:
            The response of the executed task.
        """