File size: 16,520 Bytes
09321b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
import importlib
from typing import Dict, List, Optional, Union

from .agent_types import AgentType
from .llm import LLM
from .output_parser import OutputParser, get_output_parser
from .output_wrapper import display
from .prompt import PromptGenerator, get_prompt_generator
from .retrieve import KnowledgeRetrieval, ToolRetrieval
from .tools import TOOL_INFO_LIST


class AgentExecutor:
    """Core orchestrator of the ms agent.

    Coordinates the interaction loop between the user task, the LLM and the
    tools: a prompt is generated, the LLM response is parsed into an action,
    the corresponding tool is executed, and the result is fed back into the
    next prompt round until the LLM emits no further action.
    """

    def custom_run_init(self,
                        task: str,
                        remote: bool = False,
                        print_info: bool = False,
                        append_files: Optional[list] = None):
        """Prepare all per-run state without entering the interaction loop.

        Mirrors the setup phase of :meth:`run`; a caller can then drive the
        loop step by step via :meth:`custom_gene_prompt` and
        :meth:`custom_parse_llm`.

        Args:
            task (str): concrete task given by the user.
            remote (bool, optional): whether tools execute in remote mode.
                Defaults to False.
            print_info (bool, optional): whether to print prompt info.
                Defaults to False.
            append_files (Optional[list], optional): files used only in this
                run. Defaults to an empty list.

        Returns:
            tuple: ``(tool_list, knowledge_list, function_list, llm_result,
            exec_result, idx, final_res, remote, print_info)`` — the state
            consumed by the ``custom_*`` step methods.
        """
        # Fresh list per call: the previous `append_files: list = []`
        # signature shared one mutable default across calls.
        if append_files is None:
            append_files = []

        tool_list = self.retrieve_tools(task)
        knowledge_list = self.get_knowledge(task)

        self.prompt_generator.init_prompt(
            task, tool_list, knowledge_list, append_files=append_files)
        function_list = self.prompt_generator.get_function_list(tool_list)

        llm_result, exec_result = '', ''

        idx = 0
        final_res = []

        return tool_list, knowledge_list, function_list, llm_result, exec_result, idx, final_res, remote, print_info

    def custom_gene_prompt(self, llm_result, exec_result, idx):
        """Advance one round: bump the round counter and build the next prompt.

        Args:
            llm_result: LLM output of the previous round ('' on first round).
            exec_result: tool execution result of the previous round.
            idx (int): current round index; incremented here.

        Returns:
            tuple: ``(llm_artifacts, idx)`` — the prompt to send to the LLM
            and the updated round index.
        """
        idx += 1

        # generate prompt and call llm
        llm_artifacts = self.prompt_generator.generate(
            llm_result, exec_result)

        return llm_artifacts, idx

    def custom_parse_llm(self, llm_artifacts, llm_result, idx, final_res, remote, print_info):
        """Parse one LLM response and execute the selected tool, if any.

        NOTE: ``llm_artifacts`` is only used for debug printing here; the
        text that actually gets parsed is ``llm_result``.

        Returns:
            List[Dict]: a single-element list keyed by outcome:
                - ``{'exec_result': ...}`` on parse/tool error (terminal),
                - ``{'end_res': final_res}`` when the LLM emitted no action,
                - ``{'no_stop': {...}}`` when the loop should continue, with
                  the updated state needed for the next round.
        """
        if print_info:
            print(f'|LLM inputs in round {idx}: {llm_artifacts}')

        # parse and get tool name and arguments
        try:
            action, action_args = self.output_parser.parse_response(
                llm_result)
        except ValueError as e:
            return [{'exec_result': f'{e}'}]

        if action is None:
            # in chat mode, the final result of last instructions should be updated to prompt history
            _ = self.prompt_generator.generate(llm_result, '')

            # for summarize
            # display(llm_result, {}, idx, self.agent_type)
            return [{'end_res': final_res}]

        if action in self.available_tool_list:
            action_args = self.parse_action_args(action_args)
            tool = self.tool_list[action]

            # TODO @wenmeng.zwm remove this hack logic for image generation
            if action == 'image_gen' and self.seed:
                action_args['seed'] = self.seed
            try:
                exec_result = tool(**action_args, remote=remote)
                if print_info:
                    print(f'|exec_result: {exec_result}')

                # parse exec result and store result to agent state
                final_res.append(exec_result)
                self.parse_exec_result(exec_result)
            except Exception as e:
                exec_result = f'Action call error: {action}: {action_args}. \n Error message: {e}'
                return [{'exec_result': exec_result}]
        else:
            exec_result = f"Unknown action: '{action}'. "
            return [{'exec_result': exec_result}]

        # display result
        # display(llm_result, exec_result, idx, self.agent_type)

        return [{'no_stop': {'llm_result': llm_result, 'exec_result': exec_result, 'idx': idx, 'final_res': final_res}}]

    def __init__(self,
                 llm: LLM,
                 tool_cfg: Optional[Dict] = None,
                 agent_type: AgentType = AgentType.DEFAULT,
                 additional_tool_list: Optional[Dict] = None,
                 prompt_generator: Optional[PromptGenerator] = None,
                 output_parser: Optional[OutputParser] = None,
                 tool_retrieval: Optional[Union[bool, ToolRetrieval]] = True,
                 knowledge_retrieval: Optional[KnowledgeRetrieval] = None):
        """
        The core class of ms agent. It is responsible for the interaction
        between user, llm and tools, and returns the execution result to the
        user.

        Args:
            llm (LLM): llm model, can be loaded from local or a remote server.
            tool_cfg (Optional[Dict]): cfg of default tools. Defaults to {}.
            agent_type (AgentType, optional): agent type. Defaults to
                AgentType.DEFAULT; decides which type of agent reasoning to use.
            additional_tool_list (Optional[Dict], optional): user-defined
                additional tool list. Defaults to {}.
            prompt_generator (Optional[PromptGenerator], optional): this module
                is responsible for generating prompt according to interaction
                result. Defaults to use MSPromptGenerator.
            output_parser (Optional[OutputParser], optional): this module is
                responsible for parsing output of llm to executable actions.
                Defaults to use MsOutputParser.
            tool_retrieval (Optional[Union[bool, ToolRetrieval]], optional):
                Retrieve related tools by input task, since most of the tools
                may be useless for LLM in a specific task. If it is bool type
                and is True, will use default tool_retrieval. Defaults to True.
            knowledge_retrieval (Optional[KnowledgeRetrieval], optional): If
                user wants to use extra knowledge, this component can be used
                to retrieve related knowledge. Defaults to None.
        """
        # Normalize mutable defaults: the previous `= {}` defaults were
        # shared across all instances (mutable-default pitfall).
        tool_cfg = {} if tool_cfg is None else tool_cfg
        additional_tool_list = {} if additional_tool_list is None else additional_tool_list

        self.llm = llm

        self.agent_type = agent_type
        # the llm must know the agent type to pick its prompting conventions
        self.llm.set_agent_type(agent_type)
        self.prompt_generator = prompt_generator or get_prompt_generator(
            agent_type)
        self.output_parser = output_parser or get_output_parser(agent_type)

        self._init_tools(tool_cfg, additional_tool_list)

        # bool True selects the default retriever; a ToolRetrieval
        # instance is used as-is; False/None disables retrieval
        if isinstance(tool_retrieval, bool) and tool_retrieval:
            tool_retrieval = ToolRetrieval()
        self.tool_retrieval = tool_retrieval
        if self.tool_retrieval:
            self.tool_retrieval.construct(
                [str(t) for t in self.tool_list.values()])
        self.knowledge_retrieval = knowledge_retrieval
        self.reset()
        self.seed = None

    def _init_tools(self,
                    tool_cfg: Optional[Dict] = None,
                    additional_tool_list: Optional[Dict] = None):
        """init tool list of agent. We provide a default tool list, which is initialized by a cfg file.
        user can also provide user-defined tools by additional_tool_list.
        The key of additional_tool_list is tool name, and the value is corresponding object.

        Args:
            tool_cfg (Optional[Dict]): default tool cfg. Defaults to {}.
            additional_tool_list (Optional[Dict], optional): user-defined tools. Defaults to {}.
        """
        # avoid shared mutable defaults
        tool_cfg = {} if tool_cfg is None else tool_cfg
        additional_tool_list = {} if additional_tool_list is None else additional_tool_list

        self.tool_list = {}
        tool_info_list = {**TOOL_INFO_LIST, **additional_tool_list}
        # tools_module = importlib.import_module('modelscope_agent.tools')
        from . import tools as tools_module

        for tool_name in tool_cfg.keys():
            # only instantiate tools explicitly enabled via the 'use' flag
            if tool_cfg[tool_name].get('use', False):
                assert tool_name in tool_info_list, f'Invalid tool name: {tool_name}, ' \
                                                    f'available ones are: {tool_info_list.keys()}'
                tool_class_name = tool_info_list[tool_name]
                tool_class = getattr(tools_module, tool_class_name)
                # the tool class may declare a canonical name differing
                # from the cfg key; register it under the class's name
                tool_name = tool_class.name
                self.tool_list[tool_name] = tool_class(tool_cfg)

        self.tool_list = {**self.tool_list, **additional_tool_list}
        # self.available_tool_list = deepcopy(self.tool_list)
        self.set_available_tools(self.tool_list.keys())

    def set_available_tools(self, available_tool_list):
        """Restrict the set of tools the agent may invoke in this run.

        Args:
            available_tool_list: iterable of tool names; each must already
                be registered in ``self.tool_list``.

        Raises:
            ValueError: if a name is not a registered tool.
        """
        # TODO @wenmeng.zwm refine tool init
        for t in available_tool_list:
            if t not in self.tool_list:
                raise ValueError(
                    f'Unsupported tools found:{t}, please check, valid ones: {self.tool_list.keys()}'
                )

        self.available_tool_list = {
            k: self.tool_list[k]
            for k in available_tool_list
        }

    def retrieve_tools(self, query: str) -> List[str]:
        """retrieve tools given query

        Also narrows ``self.available_tool_list`` to the retrieved subset
        when a tool retriever is configured.

        Args:
            query (str): query

        Returns:
            a view of the available tool objects (NOTE(review): despite the
            ``List[str]`` annotation, this returns ``dict_values`` of tool
            instances — kept for interface compatibility).
        """
        if self.tool_retrieval:
            retrieve_tools = self.tool_retrieval.retrieve(query)
            self.set_available_tools(available_tool_list=retrieve_tools.keys())
        return self.available_tool_list.values()

    def get_knowledge(self, query: str) -> List[str]:
        """retrieve knowledge given query

        Args:
            query (str): query

        Returns:
            List[str]: related knowledge entries; empty list when no
            knowledge retriever is configured.
        """
        return self.knowledge_retrieval.retrieve(
            query) if self.knowledge_retrieval else []

    def run(self,
            task: str,
            remote: bool = False,
            print_info: bool = False,
            append_files: Optional[list] = None) -> List[Dict]:
        """ use llm and tools to execute task given by user

        Args:
            task (str): concrete task
            remote (bool, optional): whether to execute tool in remote mode. Defaults to False.
            print_info (bool, optional): whether to print prompt info. Defaults to False.
            append_files (Optional[list], optional): files used only in this
                run. Defaults to an empty list.

        Returns:
            List[Dict]: execute result. One task may need to interact with llm multiple times,
            so a list of dict is returned. Each dict contains the result of one interaction.
        """
        # fresh list per call instead of a shared mutable default
        if append_files is None:
            append_files = []

        # retrieve tools
        tool_list = self.retrieve_tools(task)
        knowledge_list = self.get_knowledge(task)

        self.prompt_generator.init_prompt(
            task, tool_list, knowledge_list, append_files=append_files)
        function_list = self.prompt_generator.get_function_list(tool_list)

        llm_result, exec_result = '', ''

        idx = 0
        final_res = []

        while True:
            idx += 1

            # generate prompt and call llm
            llm_artifacts = self.prompt_generator.generate(
                llm_result, exec_result)
            try:
                llm_result = self.llm.generate(llm_artifacts, function_list)
            except RuntimeError as e:
                return [{'exec_result': str(e)}]

            if print_info:
                print(f'|LLM inputs in round {idx}: {llm_artifacts}')

            # parse and get tool name and arguments
            try:
                action, action_args = self.output_parser.parse_response(
                    llm_result)
            except ValueError as e:
                return [{'exec_result': f'{e}'}]

            if action is None:
                # in chat mode, the final result of last instructions should be updated to prompt history
                _ = self.prompt_generator.generate(llm_result, '')

                # for summarize
                display(llm_result, {}, idx, self.agent_type)
                return final_res

            if action in self.available_tool_list:
                action_args = self.parse_action_args(action_args)
                tool = self.tool_list[action]

                # TODO @wenmeng.zwm remove this hack logic for image generation
                if action == 'image_gen' and self.seed:
                    action_args['seed'] = self.seed
                try:
                    exec_result = tool(**action_args, remote=remote)
                    if print_info:
                        print(f'|exec_result: {exec_result}')

                    # parse exec result and store result to agent state
                    final_res.append(exec_result)
                    self.parse_exec_result(exec_result)
                except Exception as e:
                    exec_result = f'Action call error: {action}: {action_args}. \n Error message: {e}'
                    return [{'exec_result': exec_result}]
            else:
                exec_result = f"Unknown action: '{action}'. "
                return [{'exec_result': exec_result}]

            # display result
            display(llm_result, exec_result, idx, self.agent_type)

    def stream_run(self,
                   task: str,
                   remote: bool = True,
                   print_info: bool = False,
                   append_files: Optional[list] = None) -> Dict:
        """this is a stream version of run, which can be used in scenario like gradio.
        It will yield the result of each interaction, so that the caller can display the result

        Args:
            task (str): concrete task
            remote (bool, optional): whether to execute tool in remote mode. Defaults to True.
            print_info (bool, optional): whether to print prompt info. Defaults to False.
            append_files (Optional[list], optional): files used only in this
                run; no need to record them to global state. Defaults to an
                empty list.

        Yields:
            Iterator[Dict]: iterator of llm response and tool execution result
        """
        # fresh list per call instead of a shared mutable default
        if append_files is None:
            append_files = []

        # retrieve tools
        tool_list = self.retrieve_tools(task)
        knowledge_list = self.get_knowledge(task)

        self.prompt_generator.init_prompt(
            task,
            tool_list,
            knowledge_list,
            append_files=append_files,
        )
        function_list = self.prompt_generator.get_function_list(tool_list)

        llm_result, exec_result = '', ''

        idx = 0

        while True:
            idx += 1
            llm_artifacts = self.prompt_generator.generate(
                llm_result, exec_result)
            if print_info:
                print(f'|LLM inputs in round {idx}:\n{llm_artifacts}')

            llm_result = ''
            try:
                for s in self.llm.stream_generate(llm_artifacts,
                                                  function_list):
                    llm_result += s
                    yield {'llm_text': s}
            except RuntimeError:
                # fall back to non-streaming generation on stream failure
                s = self.llm.generate(llm_artifacts)
                llm_result += s
                yield {'llm_text': s}
            except Exception as e:
                yield {'llm_text': str(e)}

            # parse and get tool name and arguments
            try:
                action, action_args = self.output_parser.parse_response(
                    llm_result)
            except ValueError as e:
                yield {'exec_result': f'{e}'}
                return

            if action is None:
                # in chat mode, the final result of last instructions should be updated to prompt history
                _ = self.prompt_generator.generate(llm_result, '')
                yield {'is_final': True}
                return

            if action in self.available_tool_list:
                # yield observation to as end of action input symbol asap
                yield {'llm_text': 'Observation: '}
                action_args = self.parse_action_args(action_args)
                tool = self.tool_list[action]

                # TODO @wenmeng.zwm remove this hack logic for image generation
                if action == 'image_gen' and self.seed:
                    action_args['seed'] = self.seed
                try:
                    exec_result = tool(**action_args, remote=remote)
                    yield {'exec_result': exec_result}

                    # parse exec result and update state
                    self.parse_exec_result(exec_result)
                except Exception as e:
                    exec_result = f'Action call error: {action}: {action_args}. \n Error message: {e}'
                    yield {'exec_result': exec_result}
                    self.prompt_generator.reset()
                    return
            else:
                exec_result = f"Unknown action: '{action}'. "
                yield {'exec_result': exec_result}
                self.prompt_generator.reset()
                return

    def reset(self):
        """
        clear history and agent state
        """
        self.prompt_generator.reset()
        self.agent_state = {}

    def parse_action_args(self, action_args):
        """
        replace action_args in str to Image/Video/Audio Wrapper, so that tool can handle them

        Looks each argument value up in ``self.agent_state`` and substitutes
        the stored wrapper object when present; the raw value is kept
        otherwise (and on any lookup error, e.g. an unhashable value).
        """
        parsed_action_args = {}
        for name, arg in action_args.items():
            try:
                true_arg = self.agent_state.get(arg, arg)
            except Exception as e:
                print(f'Error when parsing action args: {e}, using fall back')
                true_arg = arg
            parsed_action_args[name] = true_arg
        return parsed_action_args

    def parse_exec_result(self, exec_result, *args, **kwargs):
        """
        update exec result to agent state.
        key is the str representation of the result.

        NOTE(review): assumes ``exec_result`` is dict-like (it is iterated
        via ``.items()``) — confirm against the tool return contract.
        """
        for k, v in exec_result.items():
            self.agent_state[str(v)] = v