Duibonduil commited on
Commit
59c91c2
·
verified ·
1 Parent(s): b21a164

Upload 2 files

Browse files
aworld/runners/callback/decorator.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+
4
+ import functools
5
+ import inspect
6
+ import logging
7
+ from typing import Dict, Any, Callable, Optional, Union, Awaitable
8
+
9
+ # Configure logging
10
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
11
+ logger = logging.getLogger("callback_registry")
12
+
13
+
14
+ class CallbackRegistry:
15
+ """Callback function registry, used to manage and execute callback functions"""
16
+
17
+ # Registry for storing decorated callback functions
18
+ _registry: Dict[str, Callable] = {}
19
+
20
+ @classmethod
21
+ def register(cls, key_name: str, func: Callable) -> Callable:
22
+ """Register callback function to the registry
23
+
24
+ Args:
25
+ key_name: Unique identifier for the callback function
26
+ func: Callback function to register
27
+
28
+ Returns:
29
+ Registered callback function
30
+ """
31
+ # Check if a callback function with the same key_name already exists
32
+ if key_name in cls._registry:
33
+ existing_func = cls._registry[key_name]
34
+ logger.warning(
35
+ f"Callback function '{key_name}' already exists and will be overwritten! "
36
+ f"Original function: {existing_func.__name__ if hasattr(existing_func, '__name__') else str(existing_func)}, "
37
+ f"New function: {func.__name__ if hasattr(func, '__name__') else str(func)}"
38
+ )
39
+
40
+ cls._registry[key_name] = func
41
+ return func
42
+
43
+ @classmethod
44
+ def get(cls, key_name: str) -> Optional[Callable]:
45
+ """Get registered callback function by key_name
46
+
47
+ Args:
48
+ key_name: Unique identifier for the callback function
49
+
50
+ Returns:
51
+ Registered callback function, or None if not found
52
+ """
53
+ return cls._registry.get(key_name)
54
+
55
+ @classmethod
56
+ async def execute(
57
+ cls,
58
+ key_name: str,
59
+ tool: Any,
60
+ args: Dict[str, Any],
61
+ tool_context: Any,
62
+ tool_response: Optional[Dict[str, Any]] = None
63
+ ) -> Optional[Dict[str, Any]]:
64
+ """Execute registered callback function
65
+
66
+ Args:
67
+ key_name: Unique identifier for the callback function
68
+ tool: Tool object
69
+ args: Tool arguments
70
+ tool_context: Tool context
71
+ tool_response: Tool response (for post-callbacks)
72
+
73
+ Returns:
74
+ Return value of the callback function, or None if the callback function doesn't exist
75
+ """
76
+ callback = cls.get(key_name)
77
+ if not callback:
78
+ return None
79
+
80
+ # Determine parameters based on callback type
81
+ if tool_response is not None:
82
+ # Post-callback
83
+ result = callback(tool, args, tool_context, tool_response)
84
+ else:
85
+ # Pre-callback
86
+ result = callback(tool, args, tool_context)
87
+
88
+ # Handle asynchronous callbacks
89
+ if inspect.isawaitable(result):
90
+ result = await result
91
+
92
+ return result
93
+
94
+ @classmethod
95
+ def list(cls) -> Dict[str, str]:
96
+ """List all registered callback functions
97
+
98
+ Returns:
99
+ Dictionary containing callback function names and descriptions
100
+ """
101
+ return {
102
+ key: func.__name__ if hasattr(func, '__name__') else str(func)
103
+ for key, func in cls._registry.items()
104
+ }
105
+
106
+
107
+ def reg_callback(key_name: str):
108
+ """Decorator for registering callback functions
109
+
110
+ Args:
111
+ key_name: Unique identifier for the callback function
112
+
113
+ Returns:
114
+ Decorator function
115
+ """
116
+ def decorator(func):
117
+ # Register function to the global registry
118
+ CallbackRegistry.register(key_name, func)
119
+
120
+ @functools.wraps(func)
121
+ def wrapper(*args, **kwargs):
122
+ return func(*args, **kwargs)
123
+
124
+ return wrapper
125
+
126
+ return decorator
127
+
128
+
129
+ # For backward compatibility, keep these functions
130
+ def get_callback(key_name: str) -> Optional[Callable]:
131
+ """Get registered callback function by key_name
132
+
133
+ Args:
134
+ key_name: Unique identifier for the callback function
135
+
136
+ Returns:
137
+ Registered callback function, or None if not found
138
+ """
139
+ return CallbackRegistry.get(key_name)
140
+
141
+
142
+ async def execute_callback(
143
+ key_name: str,
144
+ tool: Any,
145
+ args: Dict[str, Any],
146
+ tool_context: Any,
147
+ tool_response: Optional[Dict[str, Any]] = None
148
+ ) -> Optional[Dict[str, Any]]:
149
+ """Execute registered callback function
150
+
151
+ Args:
152
+ key_name: Unique identifier for the callback function
153
+ tool: Tool object
154
+ args: Tool arguments
155
+ tool_context: Tool context
156
+ tool_response: Tool response (for post-callbacks)
157
+
158
+ Returns:
159
+ Return value of the callback function, or None if the callback function doesn't exist
160
+ """
161
+ return await CallbackRegistry.execute(key_name, tool, args, tool_context, tool_response)
162
+
163
+
164
+ def list_callbacks() -> Dict[str, str]:
165
+ """List all registered callback functions
166
+
167
+ Returns:
168
+ Dictionary containing callback function names and descriptions
169
+ """
170
+ return CallbackRegistry.list()
aworld/runners/callback/tool.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import AsyncGenerator
3
+
4
+ from aworld.core.task import TaskResponse
5
+ from aworld.models.model_response import ModelResponse
6
+ from aworld.output import Output, MessageOutput
7
+ from aworld.runners.callback.decorator import CallbackRegistry
8
+ from aworld.runners.handler.base import DefaultHandler
9
+ from aworld.core.common import TaskItem, Observation
10
+ from aworld.core.context.base import Context
11
+ from aworld.core.event.base import Message, Constants, TopicType
12
+ from aworld.logs.util import logger
13
+
14
+ class ToolCallbackHandler(DefaultHandler):
15
+ def __init__(self, runner):
16
+ self.runner = runner
17
+
18
+ async def handle(self, message):
19
+ if message.category != Constants.TOOL_CALLBACK:
20
+ return
21
+ logger.info(f"-------ToolCallbackHandler start handle message----: {message}")
22
+ outputs = self.runner.task.outputs
23
+ output = None
24
+ try:
25
+ payload = message.payload
26
+ if not payload or not payload[0]:
27
+ return
28
+ observation=payload[0]
29
+ if not isinstance(observation, Observation):
30
+ return
31
+ if not observation.action_result:
32
+ return
33
+ for res in observation.action_result:
34
+ if not res or not res.content or not res.tool_name or not res.action_name:
35
+ continue
36
+ callback_func = CallbackRegistry.get(res.tool_name + "__" + res.action_name)
37
+ if not callback_func:
38
+ continue
39
+ callback_func(res)
40
+ logger.info(f"-------ToolCallbackHandler callback_func-res: {res}")
41
+ logger.info(f"-------ToolCallbackHandler end handle message: {observation}")
42
+ except Exception as e:
43
+ # todo
44
+ logger.warning(f"ToolCallbackHandler Failed to parse payload: {e}")
45
+ yield Message(
46
+ category=Constants.TASK,
47
+ payload=TaskItem(msg="Failed to parse output.", data=payload, stop=True),
48
+ sender=self.name(),
49
+ session_id=Context.instance().session_id,
50
+ topic=TopicType.ERROR
51
+ )
52
+ finally:
53
+ #todo
54
+ if output:
55
+ if not output.metadata:
56
+ output.metadata = {}
57
+ output.metadata['sender'] = message.sender
58
+ output.metadata['receiver'] = message.receiver
59
+ await outputs.add_output(output)
60
+ # 1\Update the current message node status
61
+ # 2\Update the incoming message node status
62
+
63
+
64
+ return
65
+
66
+