Highthoughts commited on
Commit
a6a9bfa
·
1 Parent(s): bc74a8c
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .github/dependabot.yml +0 -12
  2. .pre-commit-config.yaml +6 -8
  3. app.py +2 -2
  4. examples/agent_api_web_demo.py +4 -2
  5. lagent.egg-info/PKG-INFO +608 -0
  6. lagent.egg-info/SOURCES.txt +71 -0
  7. lagent.egg-info/dependency_links.txt +1 -0
  8. lagent.egg-info/requires.txt +59 -0
  9. lagent.egg-info/top_level.txt +1 -0
  10. lagent/__pycache__/__init__.cpython-310.pyc +0 -0
  11. lagent/__pycache__/schema.cpython-310.pyc +0 -0
  12. lagent/__pycache__/version.cpython-310.pyc +0 -0
  13. lagent/actions/__init__.py +20 -31
  14. lagent/actions/__pycache__/__init__.cpython-310.pyc +0 -0
  15. lagent/actions/__pycache__/action_executor.cpython-310.pyc +0 -0
  16. lagent/actions/__pycache__/arxiv_search.cpython-310.pyc +0 -0
  17. lagent/actions/__pycache__/base_action.cpython-310.pyc +0 -0
  18. lagent/actions/__pycache__/bing_map.cpython-310.pyc +0 -0
  19. lagent/actions/__pycache__/builtin_actions.cpython-310.pyc +0 -0
  20. lagent/actions/__pycache__/google_scholar_search.cpython-310.pyc +0 -0
  21. lagent/actions/__pycache__/google_search.cpython-310.pyc +0 -0
  22. lagent/actions/__pycache__/ipython_interactive.cpython-310.pyc +0 -0
  23. lagent/actions/__pycache__/ipython_interpreter.cpython-310.pyc +0 -0
  24. lagent/actions/__pycache__/ipython_manager.cpython-310.pyc +0 -0
  25. lagent/actions/__pycache__/parser.cpython-310.pyc +0 -0
  26. lagent/actions/__pycache__/ppt.cpython-310.pyc +0 -0
  27. lagent/actions/__pycache__/python_interpreter.cpython-310.pyc +0 -0
  28. lagent/actions/__pycache__/weather_query.cpython-310.pyc +0 -0
  29. lagent/actions/__pycache__/web_browser.cpython-310.pyc +0 -0
  30. lagent/actions/base_action.py +55 -42
  31. lagent/actions/weather_query.py +71 -0
  32. lagent/actions/web_browser.py +283 -232
  33. lagent/agents/__init__.py +4 -28
  34. lagent/agents/__pycache__/__init__.cpython-310.pyc +0 -0
  35. lagent/agents/__pycache__/agent.cpython-310.pyc +0 -0
  36. lagent/agents/__pycache__/react.cpython-310.pyc +0 -0
  37. lagent/agents/__pycache__/stream.cpython-310.pyc +0 -0
  38. lagent/agents/agent.py +117 -187
  39. lagent/agents/aggregator/__pycache__/__init__.cpython-310.pyc +0 -0
  40. lagent/agents/aggregator/__pycache__/default_aggregator.cpython-310.pyc +0 -0
  41. lagent/agents/aggregator/__pycache__/tool_aggregator.cpython-310.pyc +0 -0
  42. lagent/agents/react.py +73 -76
  43. lagent/agents/stream.py +94 -65
  44. lagent/distributed/http_serve/api_server.py +8 -16
  45. lagent/hooks/__pycache__/__init__.cpython-310.pyc +0 -0
  46. lagent/hooks/__pycache__/action_preprocessor.cpython-310.pyc +0 -0
  47. lagent/hooks/__pycache__/hook.cpython-310.pyc +0 -0
  48. lagent/hooks/__pycache__/logger.cpython-310.pyc +0 -0
  49. lagent/hooks/logger.py +10 -5
  50. lagent/llms/__init__.py +3 -11
.github/dependabot.yml DELETED
@@ -1,12 +0,0 @@
1
- # To get started with Dependabot version updates, you'll need to specify which
2
- # package ecosystems to update and where the package manifests are located.
3
- # Please see the documentation for all configuration options:
4
- # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
5
-
6
- version: 2
7
- updates:
8
- - package-ecosystem: "" # See documentation for possible values
9
- directory: "/" # Location of package manifests
10
- schedule:
11
- interval: "weekly"
12
-
 
 
 
 
 
 
 
 
 
 
 
 
 
.pre-commit-config.yaml CHANGED
@@ -1,22 +1,20 @@
1
  exclude: ^(tests/data|scripts|ftdp/protocols|ftdp/template_configs|ftdp/tool_dicts)/
2
  repos:
3
  - repo: https://github.com/PyCQA/flake8
4
- rev: 7.1.1
5
  hooks:
6
  - id: flake8
7
  - repo: https://github.com/PyCQA/isort
8
  rev: 5.13.2
9
  hooks:
10
  - id: isort
11
- args: ["--profile", "black", "--filter-files", "--line-width", "119"]
12
  - repo: https://github.com/psf/black
13
- rev: 24.10.0
14
  hooks:
15
  - id: black
16
  args: ["--line-length", "119", "--skip-string-normalization"]
17
-
18
  - repo: https://github.com/pre-commit/pre-commit-hooks
19
- rev: v5.0.0
20
  hooks:
21
  - id: trailing-whitespace
22
  - id: check-yaml
@@ -29,7 +27,7 @@ repos:
29
  - id: mixed-line-ending
30
  args: ["--fix=lf"]
31
  - repo: https://github.com/executablebooks/mdformat
32
- rev: 0.7.21
33
  hooks:
34
  - id: mdformat
35
  args: ["--number"]
@@ -38,11 +36,11 @@ repos:
38
  - mdformat_frontmatter
39
  - linkify-it-py
40
  - repo: https://github.com/codespell-project/codespell
41
- rev: v2.3.0
42
  hooks:
43
  - id: codespell
44
  - repo: https://github.com/asottile/pyupgrade
45
- rev: v3.19.1
46
  hooks:
47
  - id: pyupgrade
48
  args: ["--py36-plus"]
 
1
  exclude: ^(tests/data|scripts|ftdp/protocols|ftdp/template_configs|ftdp/tool_dicts)/
2
  repos:
3
  - repo: https://github.com/PyCQA/flake8
4
+ rev: 7.0.0
5
  hooks:
6
  - id: flake8
7
  - repo: https://github.com/PyCQA/isort
8
  rev: 5.13.2
9
  hooks:
10
  - id: isort
 
11
  - repo: https://github.com/psf/black
12
+ rev: 22.8.0
13
  hooks:
14
  - id: black
15
  args: ["--line-length", "119", "--skip-string-normalization"]
 
16
  - repo: https://github.com/pre-commit/pre-commit-hooks
17
+ rev: v4.5.0
18
  hooks:
19
  - id: trailing-whitespace
20
  - id: check-yaml
 
27
  - id: mixed-line-ending
28
  args: ["--fix=lf"]
29
  - repo: https://github.com/executablebooks/mdformat
30
+ rev: 0.7.17
31
  hooks:
32
  - id: mdformat
33
  args: ["--number"]
 
36
  - mdformat_frontmatter
37
  - linkify-it-py
38
  - repo: https://github.com/codespell-project/codespell
39
+ rev: v2.2.6
40
  hooks:
41
  - id: codespell
42
  - repo: https://github.com/asottile/pyupgrade
43
+ rev: v3.15.0
44
  hooks:
45
  - id: pyupgrade
46
  args: ["--py36-plus"]
app.py CHANGED
@@ -2,8 +2,8 @@
2
  Author: Highthoughts cht7613@gmail.com
3
  Date: 2025-01-30 11:02:01
4
  LastEditors: Highthoughts cht7613@gmail.com
5
- LastEditTime: 2025-01-30 11:02:13
6
- FilePath: \lagent\app.py
7
  Description:
8
 
9
  Copyright (c) 2025 by Cuihaitao, All Rights Reserved.
 
2
  Author: Highthoughts cht7613@gmail.com
3
  Date: 2025-01-30 11:02:01
4
  LastEditors: Highthoughts cht7613@gmail.com
5
+ LastEditTime: 2025-01-30 11:41:16
6
+ FilePath: \AgentTest\app.py
7
  Description:
8
 
9
  Copyright (c) 2025 by Cuihaitao, All Rights Reserved.
examples/agent_api_web_demo.py CHANGED
@@ -2,7 +2,8 @@ import copy
2
  import os
3
  from typing import List
4
  import streamlit as st
5
- from lagent.actions import ArxivSearch
 
6
  from lagent.prompts.parsers import PluginParser
7
  from lagent.agents.stream import INTERPRETER_CN, META_CN, PLUGIN_CN, AgentForInternLM, get_plugin_prompt
8
  from lagent.llms import GPTAPI
@@ -17,6 +18,7 @@ class SessionState:
17
  # 初始化插件列表
18
  action_list = [
19
  ArxivSearch(),
 
20
  ]
21
  st.session_state['plugin_map'] = {action.name: action for action in action_list}
22
  st.session_state['model_map'] = {} # 存储模型实例
@@ -50,7 +52,7 @@ class StreamlitUI:
50
  # page_title='lagent-web',
51
  # page_icon='./docs/imgs/lagent_icon.png'
52
  # )
53
- # st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
54
 
55
  def setup_sidebar(self):
56
  """设置侧边栏,选择模型和插件。"""
 
2
  import os
3
  from typing import List
4
  import streamlit as st
5
+ # from lagent.actions import ArxivSearch
6
+ from lagent.actions import ArxivSearch, WeatherQuery
7
  from lagent.prompts.parsers import PluginParser
8
  from lagent.agents.stream import INTERPRETER_CN, META_CN, PLUGIN_CN, AgentForInternLM, get_plugin_prompt
9
  from lagent.llms import GPTAPI
 
18
  # 初始化插件列表
19
  action_list = [
20
  ArxivSearch(),
21
+ WeatherQuery(),
22
  ]
23
  st.session_state['plugin_map'] = {action.name: action for action in action_list}
24
  st.session_state['model_map'] = {} # 存储模型实例
 
52
  # page_title='lagent-web',
53
  # page_icon='./docs/imgs/lagent_icon.png'
54
  # )
55
+ st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
56
 
57
  def setup_sidebar(self):
58
  """设置侧边栏,选择模型和插件。"""
lagent.egg-info/PKG-INFO ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.2
2
+ Name: lagent
3
+ Version: 0.5.0rc1
4
+ Summary: A lightweight framework for building LLM-based agents
5
+ Home-page: https://github.com/InternLM/lagent
6
+ License: Apache 2.0
7
+ Keywords: artificial general intelligence,agent,agi,llm
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Requires-Dist: aiohttp
11
+ Requires-Dist: arxiv
12
+ Requires-Dist: asyncache
13
+ Requires-Dist: asyncer
14
+ Requires-Dist: distro
15
+ Requires-Dist: duckduckgo_search==5.3.1b1
16
+ Requires-Dist: filelock
17
+ Requires-Dist: func_timeout
18
+ Requires-Dist: griffe<1.0
19
+ Requires-Dist: json5
20
+ Requires-Dist: jsonschema
21
+ Requires-Dist: jupyter==1.0.0
22
+ Requires-Dist: jupyter_client==8.6.2
23
+ Requires-Dist: jupyter_core==5.7.2
24
+ Requires-Dist: pydantic==2.6.4
25
+ Requires-Dist: requests
26
+ Requires-Dist: termcolor
27
+ Requires-Dist: tiktoken
28
+ Requires-Dist: timeout-decorator
29
+ Requires-Dist: typing-extensions
30
+ Provides-Extra: all
31
+ Requires-Dist: google-search-results; extra == "all"
32
+ Requires-Dist: lmdeploy>=0.2.5; extra == "all"
33
+ Requires-Dist: pillow; extra == "all"
34
+ Requires-Dist: python-pptx; extra == "all"
35
+ Requires-Dist: timeout_decorator; extra == "all"
36
+ Requires-Dist: torch; extra == "all"
37
+ Requires-Dist: transformers<=4.40,>=4.34; extra == "all"
38
+ Requires-Dist: vllm>=0.3.3; extra == "all"
39
+ Requires-Dist: aiohttp; extra == "all"
40
+ Requires-Dist: arxiv; extra == "all"
41
+ Requires-Dist: asyncache; extra == "all"
42
+ Requires-Dist: asyncer; extra == "all"
43
+ Requires-Dist: distro; extra == "all"
44
+ Requires-Dist: duckduckgo_search==5.3.1b1; extra == "all"
45
+ Requires-Dist: filelock; extra == "all"
46
+ Requires-Dist: func_timeout; extra == "all"
47
+ Requires-Dist: griffe<1.0; extra == "all"
48
+ Requires-Dist: json5; extra == "all"
49
+ Requires-Dist: jsonschema; extra == "all"
50
+ Requires-Dist: jupyter==1.0.0; extra == "all"
51
+ Requires-Dist: jupyter_client==8.6.2; extra == "all"
52
+ Requires-Dist: jupyter_core==5.7.2; extra == "all"
53
+ Requires-Dist: pydantic==2.6.4; extra == "all"
54
+ Requires-Dist: requests; extra == "all"
55
+ Requires-Dist: termcolor; extra == "all"
56
+ Requires-Dist: tiktoken; extra == "all"
57
+ Requires-Dist: timeout-decorator; extra == "all"
58
+ Requires-Dist: typing-extensions; extra == "all"
59
+ Provides-Extra: optional
60
+ Requires-Dist: google-search-results; extra == "optional"
61
+ Requires-Dist: lmdeploy>=0.2.5; extra == "optional"
62
+ Requires-Dist: pillow; extra == "optional"
63
+ Requires-Dist: python-pptx; extra == "optional"
64
+ Requires-Dist: timeout_decorator; extra == "optional"
65
+ Requires-Dist: torch; extra == "optional"
66
+ Requires-Dist: transformers<=4.40,>=4.34; extra == "optional"
67
+ Requires-Dist: vllm>=0.3.3; extra == "optional"
68
+ Dynamic: description
69
+ Dynamic: description-content-type
70
+ Dynamic: home-page
71
+ Dynamic: keywords
72
+ Dynamic: license
73
+ Dynamic: provides-extra
74
+ Dynamic: requires-dist
75
+ Dynamic: summary
76
+
77
+ <div id="top"></div>
78
+ <div align="center">
79
+ <img src="docs/imgs/lagent_logo.png" width="450"/>
80
+
81
+ [![docs](https://img.shields.io/badge/docs-latest-blue)](https://lagent.readthedocs.io/en/latest/)
82
+ [![PyPI](https://img.shields.io/pypi/v/lagent)](https://pypi.org/project/lagent)
83
+ [![license](https://img.shields.io/github/license/InternLM/lagent.svg)](https://github.com/InternLM/lagent/tree/main/LICENSE)
84
+ [![issue resolution](https://img.shields.io/github/issues-closed-raw/InternLM/lagent)](https://github.com/InternLM/lagent/issues)
85
+ [![open issues](https://img.shields.io/github/issues-raw/InternLM/lagent)](https://github.com/InternLM/lagent/issues)
86
+ ![Visitors](https://api.visitorbadge.io/api/visitors?path=InternLM%2Flagent%20&countColor=%23263759&style=flat)
87
+ ![GitHub forks](https://img.shields.io/github/forks/InternLM/lagent)
88
+ ![GitHub Repo stars](https://img.shields.io/github/stars/InternLM/lagent)
89
+ ![GitHub contributors](https://img.shields.io/github/contributors/InternLM/lagent)
90
+
91
+ </div>
92
+
93
+ <p align="center">
94
+ 👋 join us on <a href="https://twitter.com/intern_lm" target="_blank">𝕏 (Twitter)</a>, <a href="https://discord.gg/xa29JuW87d" target="_blank">Discord</a> and <a href="https://r.vansin.top/?r=internwx" target="_blank">WeChat</a>
95
+ </p>
96
+
97
+ ## Installation
98
+
99
+ Install from source:
100
+
101
+ ```bash
102
+ git clone https://github.com/InternLM/lagent.git
103
+ cd lagent
104
+ pip install -e .
105
+ ```
106
+
107
+ ## Usage
108
+
109
+ Lagent is inspired by the design philosophy of PyTorch. We expect that the analogy of neural network layers will make the workflow clearer and more intuitive, so users only need to focus on creating layers and defining message passing between them in a Pythonic way. This is a simple tutorial to get you quickly started with building multi-agent applications.
110
+
111
+ ### Models as Agents
112
+
113
+ Agents use `AgentMessage` for communication.
114
+
115
+ ```python
116
+ from typing import Dict, List
117
+ from lagent.agents import Agent
118
+ from lagent.schema import AgentMessage
119
+ from lagent.llms import VllmModel, INTERNLM2_META
120
+
121
+ llm = VllmModel(
122
+ path='Qwen/Qwen2-7B-Instruct',
123
+ meta_template=INTERNLM2_META,
124
+ tp=1,
125
+ top_k=1,
126
+ temperature=1.0,
127
+ stop_words=['<|im_end|>'],
128
+ max_new_tokens=1024,
129
+ )
130
+ system_prompt = '你的回答只能从“典”、“孝”、“急”三��字中选一个。'
131
+ agent = Agent(llm, system_prompt)
132
+
133
+ user_msg = AgentMessage(sender='user', content='今天天气情况')
134
+ bot_msg = agent(user_msg)
135
+ print(bot_msg)
136
+ ```
137
+
138
+ ```
139
+ content='急' sender='Agent' formatted=None extra_info=None type=None receiver=None stream_state=<AgentStatusCode.END: 0>
140
+ ```
141
+
142
+ ### Memory as State
143
+
144
+ Both input and output messages will be added to the memory of `Agent` in each forward pass. This is performed in `__call__` rather than `forward`. See the following pseudo code
145
+
146
+ ```python
147
+ def __call__(self, *message):
148
+ message = pre_hooks(message)
149
+ add_memory(message)
150
+ message = self.forward(*message)
151
+ add_memory(message)
152
+ message = post_hooks(message)
153
+ return message
154
+ ```
155
+
156
+ Inspect the memory in two ways
157
+
158
+ ```python
159
+ memory: List[AgentMessage] = agent.memory.get_memory()
160
+ print(memory)
161
+ print('-' * 120)
162
+ dumped_memory: Dict[str, List[dict]] = agent.state_dict()
163
+ print(dumped_memory['memory'])
164
+ ```
165
+
166
+ ```
167
+ [AgentMessage(content='今天天气情况', sender='user', formatted=None, extra_info=None, type=None, receiver=None, stream_state=<AgentStatusCode.END: 0>), AgentMessage(content='急', sender='Agent', formatted=None, extra_info=None, type=None, receiver=None, stream_state=<AgentStatusCode.END: 0>)]
168
+ ------------------------------------------------------------------------------------------------------------------------
169
+ [{'content': '今天天气情况', 'sender': 'user', 'formatted': None, 'extra_info': None, 'type': None, 'receiver': None, 'stream_state': <AgentStatusCode.END: 0>}, {'content': '急', 'sender': 'Agent', 'formatted': None, 'extra_info': None, 'type': None, 'receiver': None, 'stream_state': <AgentStatusCode.END: 0>}]
170
+ ```
171
+
172
+ Clear the memory of this session(`session_id=0` by default):
173
+
174
+ ```python
175
+ agent.memory.reset()
176
+ ```
177
+
178
+ ### Custom Message Aggregation
179
+
180
+ `DefaultAggregator` is called under the hood to assemble and convert `AgentMessage` to OpenAI message format.
181
+
182
+ ```python
183
+ def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]:
184
+ formatted_messages = self.aggregator.aggregate(
185
+ self.memory.get(session_id),
186
+ self.name,
187
+ self.output_format,
188
+ self.template,
189
+ )
190
+ llm_response = self.llm.chat(formatted_messages, **kwargs)
191
+ ...
192
+ ```
193
+
194
+ Implement a simple aggregator that can receive few-shots
195
+
196
+ ```python
197
+ from typing import List, Union
198
+ from lagent.memory import Memory
199
+ from lagent.prompts import StrParser
200
+ from lagent.agents.aggregator import DefaultAggregator
201
+
202
+ class FewshotAggregator(DefaultAggregator):
203
+ def __init__(self, few_shot: List[dict] = None):
204
+ self.few_shot = few_shot or []
205
+
206
+ def aggregate(self,
207
+ messages: Memory,
208
+ name: str,
209
+ parser: StrParser = None,
210
+ system_instruction: Union[str, dict, List[dict]] = None) -> List[dict]:
211
+ _message = []
212
+ if system_instruction:
213
+ _message.extend(
214
+ self.aggregate_system_intruction(system_instruction))
215
+ _message.extend(self.few_shot)
216
+ messages = messages.get_memory()
217
+ for message in messages:
218
+ if message.sender == name:
219
+ _message.append(
220
+ dict(role='assistant', content=str(message.content)))
221
+ else:
222
+ user_message = message.content
223
+ if len(_message) > 0 and _message[-1]['role'] == 'user':
224
+ _message[-1]['content'] += user_message
225
+ else:
226
+ _message.append(dict(role='user', content=user_message))
227
+ return _message
228
+
229
+ agent = Agent(
230
+ llm,
231
+ aggregator=FewshotAggregator(
232
+ [
233
+ {"role": "user", "content": "今天天气"},
234
+ {"role": "assistant", "content": "【晴】"},
235
+ ]
236
+ )
237
+ )
238
+ user_msg = AgentMessage(sender='user', content='昨天天气')
239
+ bot_msg = agent(user_msg)
240
+ print(bot_msg)
241
+ ```
242
+
243
+ ```
244
+ content='【多云转晴,夜间有轻微降温】' sender='Agent' formatted=None extra_info=None type=None receiver=None stream_state=<AgentStatusCode.END: 0>
245
+ ```
246
+
247
+ ### Flexible Response Formatting
248
+
249
+ In `AgentMessage`, `formatted` is reserved to store information parsed by `output_format` from the model output.
250
+
251
+ ```python
252
+ def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]:
253
+ ...
254
+ llm_response = self.llm.chat(formatted_messages, **kwargs)
255
+ if self.output_format:
256
+ formatted_messages = self.output_format.parse_response(llm_response)
257
+ return AgentMessage(
258
+ sender=self.name,
259
+ content=llm_response,
260
+ formatted=formatted_messages,
261
+ )
262
+ ...
263
+ ```
264
+
265
+ Use a tool parser as follows
266
+
267
+ ````python
268
+ from lagent.prompts.parsers import ToolParser
269
+
270
+ system_prompt = "逐步分析并编写Python代码解决以下问题。"
271
+ parser = ToolParser(tool_type='code interpreter', begin='```python\n', end='\n```\n')
272
+ llm.gen_params['stop_words'].append('\n```\n')
273
+ agent = Agent(llm, system_prompt, output_format=parser)
274
+
275
+ user_msg = AgentMessage(
276
+ sender='user',
277
+ content='Marie is thinking of a multiple of 63, while Jay is thinking of a '
278
+ 'factor of 63. They happen to be thinking of the same number. There are '
279
+ 'two possibilities for the number that each of them is thinking of, one '
280
+ 'positive and one negative. Find the product of these two numbers.')
281
+ bot_msg = agent(user_msg)
282
+ print(bot_msg.model_dump_json(indent=4))
283
+ ````
284
+
285
+ ````
286
+ {
287
+ "content": "首先,我们需要找出63的所有正因数和负因数。63的正因数可以通过分解63的质因数来找出,即\\(63 = 3^2 \\times 7\\)。因此,63的正因数包括1, 3, 7, 9, 21, 和 63。对于负因数,我们只需将上述正因数乘以-1。\n\n接下来,我们需要找出与63的正因数相乘的结果为63的数,以及与63的负因数相乘的结果为63的数。这可以通过将63除以每个正因数和负因数来实现。\n\n最后,我们将找到的两个数相乘得到最终答案。\n\n下面是Python代码实现:\n\n```python\ndef find_numbers():\n # 正因数\n positive_factors = [1, 3, 7, 9, 21, 63]\n # 负因数\n negative_factors = [-1, -3, -7, -9, -21, -63]\n \n # 找到与正因数相乘的结果为63的数\n positive_numbers = [63 / factor for factor in positive_factors]\n # 找到与负因数相乘的结果为63的数\n negative_numbers = [-63 / factor for factor in negative_factors]\n \n # 计算两个数的乘积\n product = positive_numbers[0] * negative_numbers[0]\n \n return product\n\nresult = find_numbers()\nprint(result)",
288
+ "sender": "Agent",
289
+ "formatted": {
290
+ "tool_type": "code interpreter",
291
+ "thought": "首先,我们需要找出63的所有正因数和负因数。63的正因数可以通过分解63的质因数来找出,即\\(63 = 3^2 \\times 7\\)。因此,63的正因数包括1, 3, 7, 9, 21, 和 63。对于负因数,我们只需将上述正因数乘以-1。\n\n接下来,我们需要找出与63的正因数相乘的结果为63的数,以及与63的负因数相乘的结果为63的数。这可以通过将63除以每个正因数和负因数来实现。\n\n最后,我们将找到的两个数相乘得到最终答案。\n\n下面是Python代码实现:\n\n",
292
+ "action": "def find_numbers():\n # 正因数\n positive_factors = [1, 3, 7, 9, 21, 63]\n # 负因数\n negative_factors = [-1, -3, -7, -9, -21, -63]\n \n # 找到与正因数相乘的结果为63的数\n positive_numbers = [63 / factor for factor in positive_factors]\n # 找到与负因数相乘的结果为63的数\n negative_numbers = [-63 / factor for factor in negative_factors]\n \n # 计算两个数的乘积\n product = positive_numbers[0] * negative_numbers[0]\n \n return product\n\nresult = find_numbers()\nprint(result)",
293
+ "status": 1
294
+ },
295
+ "extra_info": null,
296
+ "type": null,
297
+ "receiver": null,
298
+ "stream_state": 0
299
+ }
300
+ ````
301
+
302
+ ### Consistency of Tool Calling
303
+
304
+ `ActionExecutor` uses the same communication data structure as `Agent`, but requires the content of input `AgentMessage` to be a dict containing:
305
+
306
+ - `name`: tool name, e.g. `'IPythonInterpreter'`, `'WebBrowser.search'`.
307
+ - `parameters`: keyword arguments of the tool API, e.g. `{'command': 'import math;math.sqrt(2)'}`, `{'query': ['recent progress in AI']}`.
308
+
309
+ You can register custom hooks for message conversion.
310
+
311
+ ```python
312
+ from lagent.hooks import Hook
313
+ from lagent.schema import ActionReturn, ActionStatusCode, AgentMessage
314
+ from lagent.actions import ActionExecutor, IPythonInteractive
315
+
316
+ class CodeProcessor(Hook):
317
+ def before_action(self, executor, message, session_id):
318
+ message = message.copy(deep=True)
319
+ message.content = dict(
320
+ name='IPythonInteractive', parameters={'command': message.formatted['action']}
321
+ )
322
+ return message
323
+
324
+ def after_action(self, executor, message, session_id):
325
+ action_return = message.content
326
+ if isinstance(action_return, ActionReturn):
327
+ if action_return.state == ActionStatusCode.SUCCESS:
328
+ response = action_return.format_result()
329
+ else:
330
+ response = action_return.errmsg
331
+ else:
332
+ response = action_return
333
+ message.content = response
334
+ return message
335
+
336
+ executor = ActionExecutor(actions=[IPythonInteractive()], hooks=[CodeProcessor()])
337
+ bot_msg = AgentMessage(
338
+ sender='Agent',
339
+ content='首先,我们需要...',
340
+ formatted={
341
+ 'tool_type': 'code interpreter',
342
+ 'thought': '首先,我们需要...',
343
+ 'action': 'def find_numbers():\n # 正因数\n positive_factors = [1, 3, 7, 9, 21, 63]\n # 负因数\n negative_factors = [-1, -3, -7, -9, -21, -63]\n \n # 找到与正因数相乘的结果为63的数\n positive_numbers = [63 / factor for factor in positive_factors]\n # 找到与负因数相乘的结果为63的数\n negative_numbers = [-63 / factor for factor in negative_factors]\n \n # 计算两个数的乘积\n product = positive_numbers[0] * negative_numbers[0]\n \n return product\n\nresult = find_numbers()\nprint(result)',
344
+ 'status': 1
345
+ })
346
+ executor_msg = executor(bot_msg)
347
+ print(executor_msg)
348
+ ```
349
+
350
+ ```
351
+ content='3969.0' sender='ActionExecutor' formatted=None extra_info=None type=None receiver=None stream_state=<AgentStatusCode.END: 0>
352
+ ```
353
+
354
+ **For convenience, Lagent provides `InternLMActionProcessor` which is adapted to messages formatted by `ToolParser` as mentioned above.**
355
+
356
+ ### Dual Interfaces
357
+
358
+ Lagent adopts dual interface design, where almost every component(LLMs, actions, action executors...) has the corresponding asynchronous variant by prefixing its identifier with 'Async'. It is recommended to use synchronous agents for debugging and asynchronous ones for large-scale inference to make the most of idle CPU and GPU resources.
359
+
360
+ However, make sure the internal consistency of agents, i.e. asynchronous agents should be equipped with asynchronous LLMs and asynchronous action executors that drive asynchronous tools.
361
+
362
+ ```python
363
+ from lagent.llms import VllmModel, AsyncVllmModel, LMDeployPipeline, AsyncLMDeployPipeline
364
+ from lagent.actions import ActionExecutor, AsyncActionExecutor, WebBrowser, AsyncWebBrowser
365
+ from lagent.agents import Agent, AsyncAgent, AgentForInternLM, AsyncAgentForInternLM
366
+ ```
367
+
368
+ ______________________________________________________________________
369
+
370
+ ## Practice
371
+
372
+ - **Try to implement `forward` instead of `__call__` of subclasses unless necessary.**
373
+ - **Always include the `session_id` argument explicitly, which is designed for isolation of memory, LLM requests and tool invocation(e.g. maintain multiple independent IPython environments) in concurrency.**
374
+
375
+ ### Single Agent
376
+
377
+ Math agents that solve problems by programming
378
+
379
+ ````python
380
+ from lagent.agents.aggregator import InternLMToolAggregator
381
+
382
+ class Coder(Agent):
383
+ def __init__(self, model_path, system_prompt, max_turn=3):
384
+ super().__init__()
385
+ llm = VllmModel(
386
+ path=model_path,
387
+ meta_template=INTERNLM2_META,
388
+ tp=1,
389
+ top_k=1,
390
+ temperature=1.0,
391
+ stop_words=['\n```\n', '<|im_end|>'],
392
+ max_new_tokens=1024,
393
+ )
394
+ self.agent = Agent(
395
+ llm,
396
+ system_prompt,
397
+ output_format=ToolParser(
398
+ tool_type='code interpreter', begin='```python\n', end='\n```\n'
399
+ ),
400
+ # `InternLMToolAggregator` is adapted to `ToolParser` for aggregating
401
+ # messages with tool invocations and execution results
402
+ aggregator=InternLMToolAggregator(),
403
+ )
404
+ self.executor = ActionExecutor([IPythonInteractive()], hooks=[CodeProcessor()])
405
+ self.max_turn = max_turn
406
+
407
+ def forward(self, message: AgentMessage, session_id=0) -> AgentMessage:
408
+ for _ in range(self.max_turn):
409
+ message = self.agent(message, session_id=session_id)
410
+ if message.formatted['tool_type'] is None:
411
+ return message
412
+ message = self.executor(message, session_id=session_id)
413
+ return message
414
+
415
+ coder = Coder('Qwen/Qwen2-7B-Instruct', 'Solve the problem step by step with assistance of Python code')
416
+ query = AgentMessage(
417
+ sender='user',
418
+ content='Find the projection of $\\mathbf{a}$ onto $\\mathbf{b} = '
419
+ '\\begin{pmatrix} 1 \\\\ -3 \\end{pmatrix}$ if $\\mathbf{a} \\cdot \\mathbf{b} = 2.$'
420
+ )
421
+ answer = coder(query)
422
+ print(answer.content)
423
+ print('-' * 120)
424
+ for msg in coder.state_dict()['agent.memory']:
425
+ print('*' * 80)
426
+ print(f'{msg["sender"]}:\n\n{msg["content"]}')
427
+ ````
428
+
429
+ ### Multiple Agents
430
+
431
+ Asynchronous blogging agents that improve writing quality by self-refinement ([original AutoGen example](https://microsoft.github.io/autogen/0.2/docs/topics/prompting-and-reasoning/reflection/))
432
+
433
+ ```python
434
+ import asyncio
435
+ import os
436
+ from lagent.llms import AsyncGPTAPI
437
+ from lagent.agents import AsyncAgent
438
+ os.environ['OPENAI_API_KEY'] = 'YOUR_API_KEY'
439
+
440
+ class PrefixedMessageHook(Hook):
441
+ def __init__(self, prefix: str, senders: list = None):
442
+ self.prefix = prefix
443
+ self.senders = senders or []
444
+
445
+ def before_agent(self, agent, messages, session_id):
446
+ for message in messages:
447
+ if message.sender in self.senders:
448
+ message.content = self.prefix + message.content
449
+
450
+ class AsyncBlogger(AsyncAgent):
451
+ def __init__(self, model_path, writer_prompt, critic_prompt, critic_prefix='', max_turn=3):
452
+ super().__init__()
453
+ llm = AsyncGPTAPI(model_type=model_path, retry=5, max_new_tokens=2048)
454
+ self.writer = AsyncAgent(llm, writer_prompt, name='writer')
455
+ self.critic = AsyncAgent(
456
+ llm, critic_prompt, name='critic', hooks=[PrefixedMessageHook(critic_prefix, ['writer'])]
457
+ )
458
+ self.max_turn = max_turn
459
+
460
+ async def forward(self, message: AgentMessage, session_id=0) -> AgentMessage:
461
+ for _ in range(self.max_turn):
462
+ message = await self.writer(message, session_id=session_id)
463
+ message = await self.critic(message, session_id=session_id)
464
+ return await self.writer(message, session_id=session_id)
465
+
466
+ blogger = AsyncBlogger(
467
+ 'gpt-4o-2024-05-13',
468
+ writer_prompt="You are an writing assistant tasked to write engaging blogpost. You try to generate the best blogpost possible for the user's request. "
469
+ "If the user provides critique, then respond with a revised version of your previous attempts",
470
+ critic_prompt="Generate critique and recommendations on the writing. Provide detailed recommendations, including requests for length, depth, style, etc..",
471
+ critic_prefix='Reflect and provide critique on the following writing. \n\n',
472
+ )
473
+ user_prompt = (
474
+ "Write an engaging blogpost on the recent updates in {topic}. "
475
+ "The blogpost should be engaging and understandable for general audience. "
476
+ "Should have more than 3 paragraphes but no longer than 1000 words.")
477
+ bot_msgs = asyncio.get_event_loop().run_until_complete(
478
+ asyncio.gather(
479
+ *[
480
+ blogger(AgentMessage(sender='user', content=user_prompt.format(topic=topic)), session_id=i)
481
+ for i, topic in enumerate(['AI', 'Biotechnology', 'New Energy', 'Video Games', 'Pop Music'])
482
+ ]
483
+ )
484
+ )
485
+ print(bot_msgs[0].content)
486
+ print('-' * 120)
487
+ for msg in blogger.state_dict(session_id=0)['writer.memory']:
488
+ print('*' * 80)
489
+ print(f'{msg["sender"]}:\n\n{msg["content"]}')
490
+ print('-' * 120)
491
+ for msg in blogger.state_dict(session_id=0)['critic.memory']:
492
+ print('*' * 80)
493
+ print(f'{msg["sender"]}:\n\n{msg["content"]}')
494
+ ```
495
+
496
+ A multi-agent workflow that performs information retrieval, data collection and chart plotting ([original LangGraph example](https://vijaykumarkartha.medium.com/multiple-ai-agents-creating-multi-agent-workflows-using-langgraph-and-langchain-0587406ec4e6))
497
+
498
+ <div align="center">
499
+ <img src="https://miro.medium.com/v2/resize:fit:1400/format:webp/1*ffzadZCKXJT7n4JaRVFvcQ.jpeg" width="850" />
500
+ </div>
501
+
502
+ ````python
503
+ import json
504
+ from lagent.actions import IPythonInterpreter, WebBrowser, ActionExecutor
505
+ from lagent.agents.stream import get_plugin_prompt
506
+ from lagent.llms import GPTAPI
507
+ from lagent.hooks import InternLMActionProcessor
508
+
509
+ TOOL_TEMPLATE = (
510
+ "You are a helpful AI assistant, collaborating with other assistants. Use the provided tools to progress"
511
+ " towards answering the question. If you are unable to fully answer, that's OK, another assistant with"
512
+ " different tools will help where you left off. Execute what you can to make progress. If you or any of"
513
+ " the other assistants have the final answer or deliverable, prefix your response with {finish_pattern}"
514
+ " so the team knows to stop. You have access to the following tools:\n{tool_description}\nPlease provide"
515
+ " your thought process when you need to use a tool, followed by the call statement in this format:"
516
+ "\n{invocation_format}\\\\n**{system_prompt}**"
517
+ )
518
+
519
+ class DataVisualizer(Agent):
520
+ def __init__(self, model_path, research_prompt, chart_prompt, finish_pattern="Final Answer", max_turn=10):
521
+ super().__init__()
522
+ llm = GPTAPI(model_path, key='YOUR_OPENAI_API_KEY', retry=5, max_new_tokens=1024, stop_words=["```\n"])
523
+ interpreter, browser = IPythonInterpreter(), WebBrowser("BingSearch", api_key="YOUR_BING_API_KEY")
524
+ self.researcher = Agent(
525
+ llm,
526
+ TOOL_TEMPLATE.format(
527
+ finish_pattern=finish_pattern,
528
+ tool_description=get_plugin_prompt(browser),
529
+ invocation_format='```json\n{"name": {{tool name}}, "parameters": {{keyword arguments}}}\n```\n',
530
+ system_prompt=research_prompt,
531
+ ),
532
+ output_format=ToolParser(
533
+ "browser",
534
+ begin="```json\n",
535
+ end="\n```\n",
536
+ validate=lambda x: json.loads(x.rstrip('`')),
537
+ ),
538
+ aggregator=InternLMToolAggregator(),
539
+ name="researcher",
540
+ )
541
+ self.charter = Agent(
542
+ llm,
543
+ TOOL_TEMPLATE.format(
544
+ finish_pattern=finish_pattern,
545
+ tool_description=interpreter.name,
546
+ invocation_format='```python\n{{code}}\n```\n',
547
+ system_prompt=chart_prompt,
548
+ ),
549
+ output_format=ToolParser(
550
+ "interpreter",
551
+ begin="```python\n",
552
+ end="\n```\n",
553
+ validate=lambda x: x.rstrip('`'),
554
+ ),
555
+ aggregator=InternLMToolAggregator(),
556
+ name="charter",
557
+ )
558
+ self.executor = ActionExecutor([interpreter, browser], hooks=[InternLMActionProcessor()])
559
+ self.finish_pattern = finish_pattern
560
+ self.max_turn = max_turn
561
+
562
+ def forward(self, message, session_id=0):
563
+ for _ in range(self.max_turn):
564
+ message = self.researcher(message, session_id=session_id, stop_words=["```\n", "```python"]) # override llm stop words
565
+ while message.formatted["tool_type"]:
566
+ message = self.executor(message, session_id=session_id)
567
+ message = self.researcher(message, session_id=session_id, stop_words=["```\n", "```python"])
568
+ if self.finish_pattern in message.content:
569
+ return message
570
+ message = self.charter(message)
571
+ while message.formatted["tool_type"]:
572
+ message = self.executor(message, session_id=session_id)
573
+ message = self.charter(message, session_id=session_id)
574
+ if self.finish_pattern in message.content:
575
+ return message
576
+ return message
577
+
578
+ visualizer = DataVisualizer(
579
+ "gpt-4o-2024-05-13",
580
+ research_prompt="You should provide accurate data for the chart generator to use.",
581
+ chart_prompt="Any charts you display will be visible by the user.",
582
+ )
583
+ user_msg = AgentMessage(
584
+ sender='user',
585
+ content="Fetch the China's GDP over the past 5 years, then draw a line graph of it. Once you code it up, finish.")
586
+ bot_msg = visualizer(user_msg)
587
+ print(bot_msg.content)
588
+ json.dump(visualizer.state_dict(), open('visualizer.json', 'w'), ensure_ascii=False, indent=4)
589
+ ````
590
+
591
+ ## Citation
592
+
593
+ If you find this project useful in your research, please consider cite:
594
+
595
+ ```latex
596
+ @misc{lagent2023,
597
+ title={{Lagent: InternLM} a lightweight open-source framework that allows users to efficiently build large language model(LLM)-based agents},
598
+ author={Lagent Developer Team},
599
+ howpublished = {\url{https://github.com/InternLM/lagent}},
600
+ year={2023}
601
+ }
602
+ ```
603
+
604
+ ## License
605
+
606
+ This project is released under the [Apache 2.0 license](LICENSE).
607
+
608
+ <p align="right"><a href="#top">🔼 Back to top</a></p>
lagent.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE
2
+ MANIFEST.in
3
+ README.md
4
+ setup.cfg
5
+ setup.py
6
+ lagent/__init__.py
7
+ lagent/schema.py
8
+ lagent/version.py
9
+ lagent.egg-info/PKG-INFO
10
+ lagent.egg-info/SOURCES.txt
11
+ lagent.egg-info/dependency_links.txt
12
+ lagent.egg-info/requires.txt
13
+ lagent.egg-info/top_level.txt
14
+ lagent/actions/__init__.py
15
+ lagent/actions/action_executor.py
16
+ lagent/actions/arxiv_search.py
17
+ lagent/actions/base_action.py
18
+ lagent/actions/bing_map.py
19
+ lagent/actions/builtin_actions.py
20
+ lagent/actions/google_scholar_search.py
21
+ lagent/actions/google_search.py
22
+ lagent/actions/ipython_interactive.py
23
+ lagent/actions/ipython_interpreter.py
24
+ lagent/actions/ipython_manager.py
25
+ lagent/actions/parser.py
26
+ lagent/actions/ppt.py
27
+ lagent/actions/python_interpreter.py
28
+ lagent/actions/web_browser.py
29
+ lagent/agents/__init__.py
30
+ lagent/agents/agent.py
31
+ lagent/agents/react.py
32
+ lagent/agents/stream.py
33
+ lagent/agents/aggregator/__init__.py
34
+ lagent/agents/aggregator/default_aggregator.py
35
+ lagent/agents/aggregator/tool_aggregator.py
36
+ lagent/distributed/__init__.py
37
+ lagent/distributed/http_serve/__init__.py
38
+ lagent/distributed/http_serve/api_server.py
39
+ lagent/distributed/http_serve/app.py
40
+ lagent/distributed/ray_serve/__init__.py
41
+ lagent/distributed/ray_serve/ray_warpper.py
42
+ lagent/hooks/__init__.py
43
+ lagent/hooks/action_preprocessor.py
44
+ lagent/hooks/hook.py
45
+ lagent/hooks/logger.py
46
+ lagent/llms/__init__.py
47
+ lagent/llms/base_api.py
48
+ lagent/llms/base_llm.py
49
+ lagent/llms/huggingface.py
50
+ lagent/llms/lmdeploy_wrapper.py
51
+ lagent/llms/meta_template.py
52
+ lagent/llms/openai.py
53
+ lagent/llms/sensenova.py
54
+ lagent/llms/vllm_wrapper.py
55
+ lagent/memory/__init__.py
56
+ lagent/memory/base_memory.py
57
+ lagent/memory/manager.py
58
+ lagent/prompts/__init__.py
59
+ lagent/prompts/prompt_template.py
60
+ lagent/prompts/parsers/__init__.py
61
+ lagent/prompts/parsers/custom_parser.py
62
+ lagent/prompts/parsers/json_parser.py
63
+ lagent/prompts/parsers/str_parser.py
64
+ lagent/prompts/parsers/tool_parser.py
65
+ lagent/utils/__init__.py
66
+ lagent/utils/gen_key.py
67
+ lagent/utils/package.py
68
+ lagent/utils/util.py
69
+ requirements/docs.txt
70
+ requirements/optional.txt
71
+ requirements/runtime.txt
lagent.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
lagent.egg-info/requires.txt ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohttp
2
+ arxiv
3
+ asyncache
4
+ asyncer
5
+ distro
6
+ duckduckgo_search==5.3.1b1
7
+ filelock
8
+ func_timeout
9
+ griffe<1.0
10
+ json5
11
+ jsonschema
12
+ jupyter==1.0.0
13
+ jupyter_client==8.6.2
14
+ jupyter_core==5.7.2
15
+ pydantic==2.6.4
16
+ requests
17
+ termcolor
18
+ tiktoken
19
+ timeout-decorator
20
+ typing-extensions
21
+
22
+ [all]
23
+ google-search-results
24
+ lmdeploy>=0.2.5
25
+ pillow
26
+ python-pptx
27
+ timeout_decorator
28
+ torch
29
+ transformers<=4.40,>=4.34
30
+ vllm>=0.3.3
31
+ aiohttp
32
+ arxiv
33
+ asyncache
34
+ asyncer
35
+ distro
36
+ duckduckgo_search==5.3.1b1
37
+ filelock
38
+ func_timeout
39
+ griffe<1.0
40
+ json5
41
+ jsonschema
42
+ jupyter==1.0.0
43
+ jupyter_client==8.6.2
44
+ jupyter_core==5.7.2
45
+ pydantic==2.6.4
46
+ requests
47
+ termcolor
48
+ tiktoken
49
+ typing-extensions
50
+
51
+ [optional]
52
+ google-search-results
53
+ lmdeploy>=0.2.5
54
+ pillow
55
+ python-pptx
56
+ timeout_decorator
57
+ torch
58
+ transformers<=4.40,>=4.34
59
+ vllm>=0.3.3
lagent.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ lagent
lagent/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (231 Bytes). View file
 
lagent/__pycache__/schema.cpython-310.pyc ADDED
Binary file (3.46 kB). View file
 
lagent/__pycache__/version.cpython-310.pyc ADDED
Binary file (744 Bytes). View file
 
lagent/actions/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
  from .action_executor import ActionExecutor, AsyncActionExecutor
2
  from .arxiv_search import ArxivSearch, AsyncArxivSearch
3
- from .base_action import AsyncActionMixin, BaseAction, tool_api
4
  from .bing_map import AsyncBINGMap, BINGMap
5
  from .builtin_actions import FinishAction, InvalidAction, NoAction
6
  from .google_scholar_search import AsyncGoogleScholar, GoogleScholar
@@ -14,34 +14,23 @@ from .python_interpreter import AsyncPythonInterpreter, PythonInterpreter
14
  from .web_browser import AsyncWebBrowser, WebBrowser
15
 
16
  __all__ = [
17
- 'BaseAction',
18
- 'ActionExecutor',
19
- 'AsyncActionExecutor',
20
- 'InvalidAction',
21
- 'FinishAction',
22
- 'NoAction',
23
- 'BINGMap',
24
- 'AsyncBINGMap',
25
- 'ArxivSearch',
26
- 'AsyncArxivSearch',
27
- 'GoogleSearch',
28
- 'AsyncGoogleSearch',
29
- 'GoogleScholar',
30
- 'AsyncGoogleScholar',
31
- 'IPythonInterpreter',
32
- 'AsyncIPythonInterpreter',
33
- 'IPythonInteractive',
34
- 'AsyncIPythonInteractive',
35
- 'IPythonInteractiveManager',
36
- 'PythonInterpreter',
37
- 'AsyncPythonInterpreter',
38
- 'PPT',
39
- 'AsyncPPT',
40
- 'WebBrowser',
41
- 'AsyncWebBrowser',
42
- 'BaseParser',
43
- 'JsonParser',
44
- 'TupleParser',
45
- 'tool_api',
46
- 'AsyncActionMixin',
47
  ]
 
 
 
 
 
 
 
 
 
 
 
 
1
  from .action_executor import ActionExecutor, AsyncActionExecutor
2
  from .arxiv_search import ArxivSearch, AsyncArxivSearch
3
+ from .base_action import BaseAction, tool_api
4
  from .bing_map import AsyncBINGMap, BINGMap
5
  from .builtin_actions import FinishAction, InvalidAction, NoAction
6
  from .google_scholar_search import AsyncGoogleScholar, GoogleScholar
 
14
  from .web_browser import AsyncWebBrowser, WebBrowser
15
 
16
  __all__ = [
17
+ 'BaseAction', 'ActionExecutor', 'AsyncActionExecutor', 'InvalidAction',
18
+ 'FinishAction', 'NoAction', 'BINGMap', 'AsyncBINGMap', 'ArxivSearch',
19
+ 'AsyncArxivSearch', 'GoogleSearch', 'AsyncGoogleSearch', 'GoogleScholar',
20
+ 'AsyncGoogleScholar', 'IPythonInterpreter', 'AsyncIPythonInterpreter',
21
+ 'IPythonInteractive', 'AsyncIPythonInteractive',
22
+ 'IPythonInteractiveManager', 'PythonInterpreter', 'AsyncPythonInterpreter',
23
+ 'PPT', 'AsyncPPT', 'WebBrowser', 'AsyncWebBrowser', 'BaseParser',
24
+ 'JsonParser', 'TupleParser', 'tool_api'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  ]
26
+ from .weather_query import WeatherQuery
27
+ __all__ = [
28
+ 'BaseAction', 'ActionExecutor', 'AsyncActionExecutor', 'InvalidAction',
29
+ 'FinishAction', 'NoAction', 'BINGMap', 'AsyncBINGMap', 'ArxivSearch',
30
+ 'AsyncArxivSearch', 'GoogleSearch', 'AsyncGoogleSearch', 'GoogleScholar',
31
+ 'AsyncGoogleScholar', 'IPythonInterpreter', 'AsyncIPythonInterpreter',
32
+ 'IPythonInteractive', 'AsyncIPythonInteractive',
33
+ 'IPythonInteractiveManager', 'PythonInterpreter', 'AsyncPythonInterpreter',
34
+ 'PPT', 'AsyncPPT', 'WebBrowser', 'AsyncWebBrowser', 'BaseParser',
35
+ 'JsonParser', 'TupleParser', 'tool_api', 'WeatherQuery' # 这里
36
+ ]
lagent/actions/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.63 kB). View file
 
lagent/actions/__pycache__/action_executor.cpython-310.pyc ADDED
Binary file (5.84 kB). View file
 
lagent/actions/__pycache__/arxiv_search.cpython-310.pyc ADDED
Binary file (3.19 kB). View file
 
lagent/actions/__pycache__/base_action.cpython-310.pyc ADDED
Binary file (11.6 kB). View file
 
lagent/actions/__pycache__/bing_map.cpython-310.pyc ADDED
Binary file (7.79 kB). View file
 
lagent/actions/__pycache__/builtin_actions.cpython-310.pyc ADDED
Binary file (3.89 kB). View file
 
lagent/actions/__pycache__/google_scholar_search.cpython-310.pyc ADDED
Binary file (13 kB). View file
 
lagent/actions/__pycache__/google_search.cpython-310.pyc ADDED
Binary file (6.93 kB). View file
 
lagent/actions/__pycache__/ipython_interactive.cpython-310.pyc ADDED
Binary file (8.41 kB). View file
 
lagent/actions/__pycache__/ipython_interpreter.cpython-310.pyc ADDED
Binary file (16.6 kB). View file
 
lagent/actions/__pycache__/ipython_manager.cpython-310.pyc ADDED
Binary file (7.11 kB). View file
 
lagent/actions/__pycache__/parser.cpython-310.pyc ADDED
Binary file (5.48 kB). View file
 
lagent/actions/__pycache__/ppt.cpython-310.pyc ADDED
Binary file (6.81 kB). View file
 
lagent/actions/__pycache__/python_interpreter.cpython-310.pyc ADDED
Binary file (5.38 kB). View file
 
lagent/actions/__pycache__/weather_query.cpython-310.pyc ADDED
Binary file (2.66 kB). View file
 
lagent/actions/__pycache__/web_browser.cpython-310.pyc ADDED
Binary file (28.8 kB). View file
 
lagent/actions/base_action.py CHANGED
@@ -4,7 +4,7 @@ import re
4
  from abc import ABCMeta
5
  from copy import deepcopy
6
  from functools import wraps
7
- from typing import Callable, Iterable, Optional, Type, get_args, get_origin
8
 
9
  try:
10
  from typing import Annotated
@@ -24,15 +24,11 @@ from .parser import BaseParser, JsonParser, ParseError
24
  logging.getLogger('griffe').setLevel(logging.ERROR)
25
 
26
 
27
- def tool_api(
28
- func: Optional[Callable] = None,
29
- *,
30
- explode_return: bool = False,
31
- returns_named_value: bool = False,
32
- include_arguments: Optional[Iterable[str]] = None,
33
- exclude_arguments: Optional[Iterable[str]] = None,
34
- **kwargs,
35
- ):
36
  """Turn functions into tools. It will parse typehints as well as docstrings
37
  to build the tool description and attach it to functions via an attribute
38
  ``api_description``.
@@ -94,16 +90,6 @@ def tool_api(
94
  ``return_data`` field will be added to ``api_description`` only
95
  when ``explode_return`` or ``returns_named_value`` is enabled.
96
  """
97
- if include_arguments is None:
98
- exclude_arguments = exclude_arguments or set()
99
- if isinstance(exclude_arguments, str):
100
- exclude_arguments = {exclude_arguments}
101
- elif not isinstance(exclude_arguments, set):
102
- exclude_arguments = set(exclude_arguments)
103
- if 'self' not in exclude_arguments:
104
- exclude_arguments.add('self')
105
- else:
106
- include_arguments = {include_arguments} if isinstance(include_arguments, str) else set(include_arguments)
107
 
108
  def _detect_type(string):
109
  field_type = 'STRING'
@@ -120,9 +106,10 @@ def tool_api(
120
 
121
  def _explode(desc):
122
  kvs = []
123
- desc = '\nArgs:\n' + '\n'.join(
124
- [' ' + item.lstrip(' -+*#.') for item in desc.split('\n')[1:] if item.strip()]
125
- )
 
126
  docs = Docstring(desc).parse('google')
127
  if not docs:
128
  return kvs
@@ -138,12 +125,13 @@ def tool_api(
138
 
139
  def _parse_tool(function):
140
  # remove rst syntax
141
- docs = Docstring(re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse(
142
- 'google', returns_named_value=returns_named_value, **kwargs
143
- )
144
  desc = dict(
145
  name=function.__name__,
146
- description=docs[0].value if docs[0].kind is DocstringSectionKind.text else '',
 
147
  parameters=[],
148
  required=[],
149
  )
@@ -167,14 +155,17 @@ def tool_api(
167
 
168
  sig = inspect.signature(function)
169
  for name, param in sig.parameters.items():
170
- if name in exclude_arguments if include_arguments is None else name not in include_arguments:
171
  continue
172
  parameter = dict(
173
- name=param.name, type='STRING', description=args_doc.get(param.name, {}).get('description', '')
174
- )
 
 
175
  annotation = param.annotation
176
  if annotation is inspect.Signature.empty:
177
- parameter['type'] = args_doc.get(param.name, {}).get('type', 'STRING')
 
178
  else:
179
  if get_origin(annotation) is Annotated:
180
  annotation, info = get_args(annotation)
@@ -238,8 +229,9 @@ class ToolMeta(ABCMeta):
238
 
239
  def __new__(mcs, name, base, attrs):
240
  is_toolkit, tool_desc = True, dict(
241
- name=name, description=Docstring(attrs.get('__doc__', '')).parse('google')[0].value
242
- )
 
243
  for key, value in attrs.items():
244
  if callable(value) and hasattr(value, 'api_description'):
245
  api_desc = getattr(value, 'api_description')
@@ -254,7 +246,8 @@ class ToolMeta(ABCMeta):
254
  else:
255
  tool_desc.setdefault('api_list', []).append(api_desc)
256
  if not is_toolkit and 'api_list' in tool_desc:
257
- raise KeyError('`run` and other tool APIs can not be implemented ' 'at the same time')
 
258
  if is_toolkit and 'api_list' not in tool_desc:
259
  is_toolkit = False
260
  if callable(attrs.get('run')):
@@ -353,16 +346,26 @@ class BaseAction(metaclass=ToolMeta):
353
  fallback_args = {'inputs': inputs, 'name': name}
354
  if not hasattr(self, name):
355
  return ActionReturn(
356
- fallback_args, type=self.name, errmsg=f'invalid API: {name}', state=ActionStatusCode.API_ERROR
357
- )
 
 
358
  try:
359
  inputs = self._parser.parse_inputs(inputs, name)
360
  except ParseError as exc:
361
- return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR)
 
 
 
 
362
  try:
363
  outputs = getattr(self, name)(**inputs)
364
  except Exception as exc:
365
- return ActionReturn(inputs, type=self.name, errmsg=str(exc), state=ActionStatusCode.API_ERROR)
 
 
 
 
366
  if isinstance(outputs, ActionReturn):
367
  action_return = outputs
368
  if not action_return.args:
@@ -399,16 +402,26 @@ class AsyncActionMixin:
399
  fallback_args = {'inputs': inputs, 'name': name}
400
  if not hasattr(self, name):
401
  return ActionReturn(
402
- fallback_args, type=self.name, errmsg=f'invalid API: {name}', state=ActionStatusCode.API_ERROR
403
- )
 
 
404
  try:
405
  inputs = self._parser.parse_inputs(inputs, name)
406
  except ParseError as exc:
407
- return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR)
 
 
 
 
408
  try:
409
  outputs = await getattr(self, name)(**inputs)
410
  except Exception as exc:
411
- return ActionReturn(inputs, type=self.name, errmsg=str(exc), state=ActionStatusCode.API_ERROR)
 
 
 
 
412
  if isinstance(outputs, ActionReturn):
413
  action_return = outputs
414
  if not action_return.args:
 
4
  from abc import ABCMeta
5
  from copy import deepcopy
6
  from functools import wraps
7
+ from typing import Callable, Optional, Type, get_args, get_origin
8
 
9
  try:
10
  from typing import Annotated
 
24
  logging.getLogger('griffe').setLevel(logging.ERROR)
25
 
26
 
27
+ def tool_api(func: Optional[Callable] = None,
28
+ *,
29
+ explode_return: bool = False,
30
+ returns_named_value: bool = False,
31
+ **kwargs):
 
 
 
 
32
  """Turn functions into tools. It will parse typehints as well as docstrings
33
  to build the tool description and attach it to functions via an attribute
34
  ``api_description``.
 
90
  ``return_data`` field will be added to ``api_description`` only
91
  when ``explode_return`` or ``returns_named_value`` is enabled.
92
  """
 
 
 
 
 
 
 
 
 
 
93
 
94
  def _detect_type(string):
95
  field_type = 'STRING'
 
106
 
107
  def _explode(desc):
108
  kvs = []
109
+ desc = '\nArgs:\n' + '\n'.join([
110
+ ' ' + item.lstrip(' -+*#.')
111
+ for item in desc.split('\n')[1:] if item.strip()
112
+ ])
113
  docs = Docstring(desc).parse('google')
114
  if not docs:
115
  return kvs
 
125
 
126
  def _parse_tool(function):
127
  # remove rst syntax
128
+ docs = Docstring(
129
+ re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse(
130
+ 'google', returns_named_value=returns_named_value, **kwargs)
131
  desc = dict(
132
  name=function.__name__,
133
+ description=docs[0].value
134
+ if docs[0].kind is DocstringSectionKind.text else '',
135
  parameters=[],
136
  required=[],
137
  )
 
155
 
156
  sig = inspect.signature(function)
157
  for name, param in sig.parameters.items():
158
+ if name == 'self':
159
  continue
160
  parameter = dict(
161
+ name=param.name,
162
+ type='STRING',
163
+ description=args_doc.get(param.name,
164
+ {}).get('description', ''))
165
  annotation = param.annotation
166
  if annotation is inspect.Signature.empty:
167
+ parameter['type'] = args_doc.get(param.name,
168
+ {}).get('type', 'STRING')
169
  else:
170
  if get_origin(annotation) is Annotated:
171
  annotation, info = get_args(annotation)
 
229
 
230
  def __new__(mcs, name, base, attrs):
231
  is_toolkit, tool_desc = True, dict(
232
+ name=name,
233
+ description=Docstring(attrs.get('__doc__',
234
+ '')).parse('google')[0].value)
235
  for key, value in attrs.items():
236
  if callable(value) and hasattr(value, 'api_description'):
237
  api_desc = getattr(value, 'api_description')
 
246
  else:
247
  tool_desc.setdefault('api_list', []).append(api_desc)
248
  if not is_toolkit and 'api_list' in tool_desc:
249
+ raise KeyError('`run` and other tool APIs can not be implemented '
250
+ 'at the same time')
251
  if is_toolkit and 'api_list' not in tool_desc:
252
  is_toolkit = False
253
  if callable(attrs.get('run')):
 
346
  fallback_args = {'inputs': inputs, 'name': name}
347
  if not hasattr(self, name):
348
  return ActionReturn(
349
+ fallback_args,
350
+ type=self.name,
351
+ errmsg=f'invalid API: {name}',
352
+ state=ActionStatusCode.API_ERROR)
353
  try:
354
  inputs = self._parser.parse_inputs(inputs, name)
355
  except ParseError as exc:
356
+ return ActionReturn(
357
+ fallback_args,
358
+ type=self.name,
359
+ errmsg=exc.err_msg,
360
+ state=ActionStatusCode.ARGS_ERROR)
361
  try:
362
  outputs = getattr(self, name)(**inputs)
363
  except Exception as exc:
364
+ return ActionReturn(
365
+ inputs,
366
+ type=self.name,
367
+ errmsg=str(exc),
368
+ state=ActionStatusCode.API_ERROR)
369
  if isinstance(outputs, ActionReturn):
370
  action_return = outputs
371
  if not action_return.args:
 
402
  fallback_args = {'inputs': inputs, 'name': name}
403
  if not hasattr(self, name):
404
  return ActionReturn(
405
+ fallback_args,
406
+ type=self.name,
407
+ errmsg=f'invalid API: {name}',
408
+ state=ActionStatusCode.API_ERROR)
409
  try:
410
  inputs = self._parser.parse_inputs(inputs, name)
411
  except ParseError as exc:
412
+ return ActionReturn(
413
+ fallback_args,
414
+ type=self.name,
415
+ errmsg=exc.err_msg,
416
+ state=ActionStatusCode.ARGS_ERROR)
417
  try:
418
  outputs = await getattr(self, name)(**inputs)
419
  except Exception as exc:
420
+ return ActionReturn(
421
+ inputs,
422
+ type=self.name,
423
+ errmsg=str(exc),
424
+ state=ActionStatusCode.API_ERROR)
425
  if isinstance(outputs, ActionReturn):
426
  action_return = outputs
427
  if not action_return.args:
lagent/actions/weather_query.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from lagent.actions.base_action import BaseAction, tool_api
4
+ from lagent.schema import ActionReturn, ActionStatusCode
5
+
6
+ class WeatherQuery(BaseAction):
7
+ def __init__(self):
8
+ super().__init__()
9
+ self.api_key = os.getenv("weather_token")
10
+ print(self.api_key)
11
+ if not self.api_key:
12
+ raise EnvironmentError("未找到环境变量 'token'。请设置你的和风天气 API Key 到 'weather_token' 环境变量中,比如export weather_token='xxx' ")
13
+
14
+ @tool_api
15
+ def run(self, location: str) -> dict:
16
+ """
17
+ 查询实时天气信息。
18
+
19
+ Args:
20
+ location (str): 要查询的地点名称、LocationID 或经纬度坐标(如 "101010100" 或 "116.41,39.92")。
21
+
22
+ Returns:
23
+ dict: 包含天气信息的字典
24
+ * location: 地点名称
25
+ * weather: 天气状况
26
+ * temperature: 当前温度
27
+ * wind_direction: 风向
28
+ * wind_speed: 风速(公里/小时)
29
+ * humidity: 相对湿度(%)
30
+ * report_time: 数据报告时间
31
+ """
32
+ try:
33
+ # 如果 location 不是坐标格式(例如 "116.41,39.92"),则调用 GeoAPI 获取 LocationID
34
+ if not ("," in location and location.replace(",", "").replace(".", "").isdigit()):
35
+ # 使用 GeoAPI 获取 LocationID
36
+ geo_url = f"https://geoapi.qweather.com/v2/city/lookup?location={location}&key={self.api_key}"
37
+ geo_response = requests.get(geo_url)
38
+ geo_data = geo_response.json()
39
+
40
+ if geo_data.get("code") != "200" or not geo_data.get("location"):
41
+ raise Exception(f"GeoAPI 返回错误码:{geo_data.get('code')} 或未找到位置")
42
+
43
+ location = geo_data["location"][0]["id"]
44
+
45
+ # 构建天气查询的 API 请求 URL
46
+ weather_url = f"https://devapi.qweather.com/v7/weather/now?location={location}&key={self.api_key}"
47
+ response = requests.get(weather_url)
48
+ data = response.json()
49
+
50
+ # 检查 API 响应码
51
+ if data.get("code") != "200":
52
+ raise Exception(f"Weather API 返回错误码:{data.get('code')}")
53
+
54
+ # 解析和组织天气信息
55
+ weather_info = {
56
+ "location": location,
57
+ "weather": data["now"]["text"],
58
+ "temperature": data["now"]["temp"] + "°C",
59
+ "wind_direction": data["now"]["windDir"],
60
+ "wind_speed": data["now"]["windSpeed"] + " km/h",
61
+ "humidity": data["now"]["humidity"] + "%",
62
+ "report_time": data["updateTime"]
63
+ }
64
+
65
+ return {"result": weather_info}
66
+
67
+ except Exception as exc:
68
+ return ActionReturn(
69
+ errmsg=f"WeatherQuery 异常:{exc}",
70
+ state=ActionStatusCode.HTTP_ERROR
71
+ )
lagent/actions/web_browser.py CHANGED
@@ -18,6 +18,7 @@ import requests
18
  from asyncache import cached as acached
19
  from bs4 import BeautifulSoup
20
  from cachetools import TTLCache, cached
 
21
 
22
  from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
23
  from lagent.actions.parser import BaseParser, JsonParser
@@ -34,11 +35,12 @@ class BaseSearch:
34
  filtered_results = {}
35
  count = 0
36
  for url, snippet, title in results:
37
- if all(domain not in url for domain in self.black_list) and not url.endswith('.pdf'):
 
38
  filtered_results[count] = {
39
  'url': url,
40
  'summ': json.dumps(snippet, ensure_ascii=False)[1:-1],
41
- 'title': title,
42
  }
43
  count += 1
44
  if count >= self.topk:
@@ -48,17 +50,15 @@ class BaseSearch:
48
 
49
  class DuckDuckGoSearch(BaseSearch):
50
 
51
- def __init__(
52
- self,
53
- topk: int = 3,
54
- black_list: List[str] = [
55
- 'enoN',
56
- 'youtube.com',
57
- 'bilibili.com',
58
- 'researchgate.net',
59
- ],
60
- **kwargs,
61
- ):
62
  self.proxy = kwargs.get('proxy')
63
  self.timeout = kwargs.get('timeout', 30)
64
  super().__init__(topk, black_list)
@@ -67,39 +67,40 @@ class DuckDuckGoSearch(BaseSearch):
67
  def search(self, query: str, max_retry: int = 3) -> dict:
68
  for attempt in range(max_retry):
69
  try:
70
- response = self._call_ddgs(query, timeout=self.timeout, proxy=self.proxy)
 
71
  return self._parse_response(response)
72
  except Exception as e:
73
  logging.exception(str(e))
74
- warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}')
 
75
  time.sleep(random.randint(2, 5))
76
- raise Exception('Failed to get search results from DuckDuckGo after retries.')
 
77
 
78
  @acached(cache=TTLCache(maxsize=100, ttl=600))
79
  async def asearch(self, query: str, max_retry: int = 3) -> dict:
80
- from duckduckgo_search import AsyncDDGS
81
-
82
  for attempt in range(max_retry):
83
  try:
84
  ddgs = AsyncDDGS(timeout=self.timeout, proxy=self.proxy)
85
- response = await ddgs.text(query.strip("'"), max_results=10)
86
  return self._parse_response(response)
87
  except Exception as e:
88
  if isinstance(e, asyncio.TimeoutError):
89
  logging.exception('Request to DDGS timed out.')
90
  logging.exception(str(e))
91
- warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}')
 
92
  await asyncio.sleep(random.randint(2, 5))
93
- raise Exception('Failed to get search results from DuckDuckGo after retries.')
 
94
 
95
  async def _async_call_ddgs(self, query: str, **kwargs) -> dict:
96
- from duckduckgo_search import DDGS
97
-
98
  ddgs = DDGS(**kwargs)
99
  try:
100
  response = await asyncio.wait_for(
101
- asyncio.to_thread(ddgs.text, query.strip("'"), max_results=10), timeout=self.timeout
102
- )
103
  return response
104
  except asyncio.TimeoutError:
105
  logging.exception('Request to DDGS timed out.')
@@ -109,35 +110,34 @@ class DuckDuckGoSearch(BaseSearch):
109
  loop = asyncio.new_event_loop()
110
  asyncio.set_event_loop(loop)
111
  try:
112
- response = loop.run_until_complete(self._async_call_ddgs(query, **kwargs))
 
113
  return response
114
  finally:
115
  loop.close()
116
 
117
- def _parse_response(self, response: List[dict]) -> dict:
118
  raw_results = []
119
  for item in response:
120
  raw_results.append(
121
- (item['href'], item['description'] if 'description' in item else item['body'], item['title'])
122
- )
123
  return self._filter_results(raw_results)
124
 
125
 
126
  class BingSearch(BaseSearch):
127
 
128
- def __init__(
129
- self,
130
- api_key: str,
131
- region: str = 'zh-CN',
132
- topk: int = 3,
133
- black_list: List[str] = [
134
- 'enoN',
135
- 'youtube.com',
136
- 'bilibili.com',
137
- 'researchgate.net',
138
- ],
139
- **kwargs,
140
- ):
141
  self.api_key = api_key
142
  self.market = region
143
  self.proxy = kwargs.get('proxy')
@@ -151,9 +151,11 @@ class BingSearch(BaseSearch):
151
  return self._parse_response(response)
152
  except Exception as e:
153
  logging.exception(str(e))
154
- warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}')
 
155
  time.sleep(random.randint(2, 5))
156
- raise Exception('Failed to get search results from Bing Search after retries.')
 
157
 
158
  @acached(cache=TTLCache(maxsize=100, ttl=600))
159
  async def asearch(self, query: str, max_retry: int = 3) -> dict:
@@ -163,15 +165,18 @@ class BingSearch(BaseSearch):
163
  return self._parse_response(response)
164
  except Exception as e:
165
  logging.exception(str(e))
166
- warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}')
 
167
  await asyncio.sleep(random.randint(2, 5))
168
- raise Exception('Failed to get search results from Bing Search after retries.')
 
169
 
170
  def _call_bing_api(self, query: str) -> dict:
171
  endpoint = 'https://api.bing.microsoft.com/v7.0/search'
172
  params = {'q': query, 'mkt': self.market, 'count': f'{self.topk * 2}'}
173
  headers = {'Ocp-Apim-Subscription-Key': self.api_key}
174
- response = requests.get(endpoint, headers=headers, params=params, proxies=self.proxy)
 
175
  response.raise_for_status()
176
  return response.json()
177
 
@@ -181,25 +186,32 @@ class BingSearch(BaseSearch):
181
  headers = {'Ocp-Apim-Subscription-Key': self.api_key}
182
  async with aiohttp.ClientSession(raise_for_status=True) as session:
183
  async with session.get(
184
- endpoint,
185
- headers=headers,
186
- params=params,
187
- proxy=self.proxy and (self.proxy.get('http') or self.proxy.get('https')),
188
- ) as resp:
189
  return await resp.json()
190
 
191
  def _parse_response(self, response: dict) -> dict:
192
- webpages = {w['id']: w for w in response.get('webPages', {}).get('value', [])}
 
 
 
193
  raw_results = []
194
 
195
- for item in response.get('rankingResponse', {}).get('mainline', {}).get('items', []):
 
196
  if item['answerType'] == 'WebPages':
197
  webpage = webpages.get(item['value']['id'])
198
  if webpage:
199
- raw_results.append((webpage['url'], webpage['snippet'], webpage['name']))
200
- elif item['answerType'] == 'News' and item['value']['id'] == response.get('news', {}).get('id'):
 
 
201
  for news in response.get('news', {}).get('value', []):
202
- raw_results.append((news['url'], news['description'], news['name']))
 
203
 
204
  return self._filter_results(raw_results)
205
 
@@ -218,27 +230,24 @@ class BraveSearch(BaseSearch):
218
  topk (int): The number of search results returned in response from API search results.
219
  region (str): The country code string. Specifies the country where the search results come from.
220
  language (str): The language code string. Specifies the preferred language for the search results.
221
- extra_snippets (bool): Allows retrieving up to 5 additional snippets, which are alternative excerpts from the
222
- search results.
223
  **kwargs: Any other parameters related to the Brave Search API. Find more details at
224
  https://api.search.brave.com/app/documentation/web-search/get-started.
225
  """
226
 
227
- def __init__(
228
- self,
229
- api_key: str,
230
- region: str = 'ALL',
231
- language: str = 'zh-hans',
232
- extra_snippests: bool = True,
233
- topk: int = 3,
234
- black_list: List[str] = [
235
- 'enoN',
236
- 'youtube.com',
237
- 'bilibili.com',
238
- 'researchgate.net',
239
- ],
240
- **kwargs,
241
- ):
242
  self.api_key = api_key
243
  self.market = region
244
  self.proxy = kwargs.get('proxy')
@@ -256,9 +265,11 @@ class BraveSearch(BaseSearch):
256
  return self._parse_response(response)
257
  except Exception as e:
258
  logging.exception(str(e))
259
- warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}')
 
260
  time.sleep(random.randint(2, 5))
261
- raise Exception('Failed to get search results from Brave Search after retries.')
 
262
 
263
  @acached(cache=TTLCache(maxsize=100, ttl=600))
264
  async def asearch(self, query: str, max_retry: int = 3) -> dict:
@@ -268,9 +279,11 @@ class BraveSearch(BaseSearch):
268
  return self._parse_response(response)
269
  except Exception as e:
270
  logging.exception(str(e))
271
- warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}')
 
272
  await asyncio.sleep(random.randint(2, 5))
273
- raise Exception('Failed to get search results from Brave Search after retries.')
 
274
 
275
  def _call_brave_api(self, query: str) -> dict:
276
  endpoint = f'https://api.search.brave.com/res/v1/{self.search_type}/search'
@@ -280,10 +293,17 @@ class BraveSearch(BaseSearch):
280
  'search_lang': self.language,
281
  'extra_snippets': self.extra_snippests,
282
  'count': self.topk,
283
- **{key: value for key, value in self.kwargs.items() if value is not None},
 
 
 
284
  }
285
- headers = {'X-Subscription-Token': self.api_key or '', 'Accept': 'application/json'}
286
- response = requests.get(endpoint, headers=headers, params=params, proxies=self.proxy)
 
 
 
 
287
  response.raise_for_status()
288
  return response.json()
289
 
@@ -295,16 +315,22 @@ class BraveSearch(BaseSearch):
295
  'search_lang': self.language,
296
  'extra_snippets': self.extra_snippests,
297
  'count': self.topk,
298
- **{key: value for key, value in self.kwargs.items() if value is not None},
 
 
 
 
 
 
 
299
  }
300
- headers = {'X-Subscription-Token': self.api_key or '', 'Accept': 'application/json'}
301
  async with aiohttp.ClientSession(raise_for_status=True) as session:
302
  async with session.get(
303
- endpoint,
304
- headers=headers,
305
- params=params,
306
- proxy=self.proxy and (self.proxy.get('http') or self.proxy.get('https')),
307
- ) as resp:
308
  return await resp.json()
309
 
310
  def _parse_response(self, response: dict) -> dict:
@@ -315,13 +341,15 @@ class BraveSearch(BaseSearch):
315
  raw_results = []
316
 
317
  for item in filtered_result:
318
- raw_results.append(
319
- (
320
- item.get('url', ''),
321
- ' '.join(filter(None, [item.get('description'), *item.get('extra_snippets', [])])),
322
- item.get('title', ''),
323
- )
324
- )
 
 
325
  return self._filter_results(raw_results)
326
 
327
 
@@ -348,18 +376,16 @@ class GoogleSearch(BaseSearch):
348
  'search': 'organic',
349
  }
350
 
351
- def __init__(
352
- self,
353
- api_key: str,
354
- topk: int = 3,
355
- black_list: List[str] = [
356
- 'enoN',
357
- 'youtube.com',
358
- 'bilibili.com',
359
- 'researchgate.net',
360
- ],
361
- **kwargs,
362
- ):
363
  self.api_key = api_key
364
  self.proxy = kwargs.get('proxy')
365
  self.search_type = kwargs.get('search_type', 'search')
@@ -374,9 +400,12 @@ class GoogleSearch(BaseSearch):
374
  return self._parse_response(response)
375
  except Exception as e:
376
  logging.exception(str(e))
377
- warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}')
 
378
  time.sleep(random.randint(2, 5))
379
- raise Exception('Failed to get search results from Google Serper Search after retries.')
 
 
380
 
381
  @acached(cache=TTLCache(maxsize=100, ttl=600))
382
  async def asearch(self, query: str, max_retry: int = 3) -> dict:
@@ -386,19 +415,29 @@ class GoogleSearch(BaseSearch):
386
  return self._parse_response(response)
387
  except Exception as e:
388
  logging.exception(str(e))
389
- warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}')
 
390
  await asyncio.sleep(random.randint(2, 5))
391
- raise Exception('Failed to get search results from Google Serper Search after retries.')
 
 
392
 
393
  def _call_serper_api(self, query: str) -> dict:
394
  endpoint = f'https://google.serper.dev/{self.search_type}'
395
  params = {
396
  'q': query,
397
  'num': self.topk,
398
- **{key: value for key, value in self.kwargs.items() if value is not None},
 
 
 
399
  }
400
- headers = {'X-API-KEY': self.api_key or '', 'Content-Type': 'application/json'}
401
- response = requests.get(endpoint, headers=headers, params=params, proxies=self.proxy)
 
 
 
 
402
  response.raise_for_status()
403
  return response.json()
404
 
@@ -407,16 +446,22 @@ class GoogleSearch(BaseSearch):
407
  params = {
408
  'q': query,
409
  'num': self.topk,
410
- **{key: value for key, value in self.kwargs.items() if value is not None},
 
 
 
 
 
 
 
411
  }
412
- headers = {'X-API-KEY': self.api_key or '', 'Content-Type': 'application/json'}
413
  async with aiohttp.ClientSession(raise_for_status=True) as session:
414
  async with session.get(
415
- endpoint,
416
- headers=headers,
417
- params=params,
418
- proxy=self.proxy and (self.proxy.get('http') or self.proxy.get('https')),
419
- ) as resp:
420
  return await resp.json()
421
 
422
  def _parse_response(self, response: dict) -> dict:
@@ -427,34 +472,33 @@ class GoogleSearch(BaseSearch):
427
  if answer_box.get('answer'):
428
  raw_results.append(('', answer_box.get('answer'), ''))
429
  elif answer_box.get('snippet'):
430
- raw_results.append(('', answer_box.get('snippet').replace('\n', ' '), ''))
 
431
  elif answer_box.get('snippetHighlighted'):
432
- raw_results.append(('', answer_box.get('snippetHighlighted'), ''))
 
433
 
434
  if response.get('knowledgeGraph'):
435
  kg = response.get('knowledgeGraph', {})
436
  description = kg.get('description', '')
437
- attributes = '. '.join(f'{attribute}: {value}' for attribute, value in kg.get('attributes', {}).items())
 
 
438
  raw_results.append(
439
- (
440
- kg.get('descriptionLink', ''),
441
- f'{description}. {attributes}' if attributes else description,
442
- f"{kg.get('title', '')}: {kg.get('type', '')}.",
443
- )
444
- )
445
-
446
- for result in response[self.result_key_for_type[self.search_type]][: self.topk]:
447
  description = result.get('snippet', '')
448
  attributes = '. '.join(
449
- f'{attribute}: {value}' for attribute, value in result.get('attributes', {}).items()
450
- )
451
  raw_results.append(
452
- (
453
- result.get('link', ''),
454
- f'{description}. {attributes}' if attributes else description,
455
- result.get('title', ''),
456
- )
457
- )
458
 
459
  return self._filter_results(raw_results)
460
 
@@ -485,27 +529,25 @@ class TencentSearch(BaseSearch):
485
  Supports multiple values separated by commas. Example: `30010255`.
486
  """
487
 
488
- def __init__(
489
- self,
490
- secret_id: str = 'Your SecretId',
491
- secret_key: str = 'Your SecretKey',
492
- api_key: str = '',
493
- action: str = 'SearchCommon',
494
- version: str = '2020-12-29',
495
- service: str = 'tms',
496
- host: str = 'tms.tencentcloudapi.com',
497
- topk: int = 3,
498
- tsn: int = None,
499
- insite: str = None,
500
- category: str = None,
501
- vrid: str = None,
502
- black_list: List[str] = [
503
- 'enoN',
504
- 'youtube.com',
505
- 'bilibili.com',
506
- 'researchgate.net',
507
- ],
508
- ):
509
  self.secret_id = secret_id
510
  self.secret_key = secret_key
511
  self.api_key = api_key
@@ -527,9 +569,11 @@ class TencentSearch(BaseSearch):
527
  return self._parse_response(response)
528
  except Exception as e:
529
  logging.exception(str(e))
530
- warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}')
 
531
  time.sleep(random.randint(2, 5))
532
- raise Exception('Failed to get search results from Bing Search after retries.')
 
533
 
534
  @acached(cache=TTLCache(maxsize=100, ttl=600))
535
  async def asearch(self, query: str, max_retry: int = 3) -> dict:
@@ -539,9 +583,11 @@ class TencentSearch(BaseSearch):
539
  return self._parse_response(response)
540
  except Exception as e:
541
  logging.exception(str(e))
542
- warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}')
 
543
  await asyncio.sleep(random.randint(2, 5))
544
- raise Exception('Failed to get search results from Bing Search after retries.')
 
545
 
546
  def _get_headers_and_payload(self, query: str) -> tuple:
547
 
@@ -571,47 +617,33 @@ class TencentSearch(BaseSearch):
571
  ct = 'application/json; charset=utf-8'
572
  canonical_headers = f'content-type:{ct}\nhost:{self.host}\nx-tc-action:{self.action.lower()}\n'
573
  signed_headers = 'content-type;host;x-tc-action'
574
- hashed_request_payload = hashlib.sha256(payload.encode('utf-8')).hexdigest()
 
575
  canonical_request = (
576
- http_request_method
577
- + '\n'
578
- + canonical_uri
579
- + '\n'
580
- + canonical_querystring
581
- + '\n'
582
- + canonical_headers
583
- + '\n'
584
- + signed_headers
585
- + '\n'
586
- + hashed_request_payload
587
- )
588
 
589
  # ************* 步骤 2:拼接待签名字符串 *************
590
  credential_scope = date + '/' + self.service + '/' + 'tc3_request'
591
- hashed_canonical_request = hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()
592
- string_to_sign = algorithm + '\n' + str(timestamp) + '\n' + credential_scope + '\n' + hashed_canonical_request
 
 
 
593
 
594
  # ************* 步骤 3:计算签名 *************
595
  secret_date = sign(('TC3' + self.secret_key).encode('utf-8'), date)
596
  secret_service = sign(secret_date, self.service)
597
  secret_signing = sign(secret_service, 'tc3_request')
598
- signature = hmac.new(secret_signing, string_to_sign.encode('utf-8'), hashlib.sha256).hexdigest()
 
599
 
600
  # ************* 步骤 4:拼接 Authorization *************
601
  authorization = (
602
- algorithm
603
- + ' '
604
- + 'Credential='
605
- + self.secret_id
606
- + '/'
607
- + credential_scope
608
- + ', '
609
- + 'SignedHeaders='
610
- + signed_headers
611
- + ', '
612
- + 'Signature='
613
- + signature
614
- )
615
 
616
  # ************* 步骤 5:构造并发起请求 *************
617
  headers = {
@@ -620,7 +652,7 @@ class TencentSearch(BaseSearch):
620
  'Host': self.host,
621
  'X-TC-Action': self.action,
622
  'X-TC-Timestamp': str(timestamp),
623
- 'X-TC-Version': self.version,
624
  }
625
  # if self.region:
626
  # headers["X-TC-Region"] = self.region
@@ -638,14 +670,16 @@ class TencentSearch(BaseSearch):
638
  except Exception as e:
639
  logging.warning(str(e))
640
  import ast
641
-
642
  resp = ast.literal_eval(resp)
643
  return resp.get('Response', dict())
644
 
645
  async def _async_call_tencent_api(self, query: str):
646
  headers, payload = self._get_headers_and_payload(query)
647
  async with aiohttp.ClientSession(raise_for_status=True) as session:
648
- async with session.post('https://' + self.host.lstrip('/'), headers=headers, data=payload) as resp:
 
 
 
649
  return (await resp.json()).get('Response', {})
650
 
651
  def _parse_response(self, response: dict) -> dict:
@@ -654,7 +688,8 @@ class TencentSearch(BaseSearch):
654
  display = json.loads(item['Display'])
655
  if not display['url']:
656
  continue
657
- raw_results.append((display['url'], display['content'] or display['abstract_info'], display['title']))
 
658
  return self._filter_results(raw_results)
659
 
660
 
@@ -680,8 +715,8 @@ class ContentFetcher:
680
  async def afetch(self, url: str) -> Tuple[bool, str]:
681
  try:
682
  async with aiohttp.ClientSession(
683
- raise_for_status=True, timeout=aiohttp.ClientTimeout(self.timeout)
684
- ) as session:
685
  async with session.get(url) as resp:
686
  html = await resp.text(errors='ignore')
687
  text = BeautifulSoup(html, 'html.parser').get_text()
@@ -692,24 +727,24 @@ class ContentFetcher:
692
 
693
 
694
  class WebBrowser(BaseAction):
695
- """Wrapper around the Web Browser Tool."""
696
-
697
- def __init__(
698
- self,
699
- searcher_type: str = 'DuckDuckGoSearch',
700
- timeout: int = 5,
701
- black_list: Optional[List[str]] = [
702
- 'enoN',
703
- 'youtube.com',
704
- 'bilibili.com',
705
- 'researchgate.net',
706
- ],
707
- topk: int = 20,
708
- description: Optional[dict] = None,
709
- parser: Type[BaseParser] = JsonParser,
710
- **kwargs,
711
- ):
712
- self.searcher = eval(searcher_type)(black_list=black_list, topk=topk, **kwargs)
713
  self.fetcher = ContentFetcher(timeout=timeout)
714
  self.search_results = None
715
  super().__init__(description, parser)
@@ -724,7 +759,10 @@ class WebBrowser(BaseAction):
724
  search_results = {}
725
 
726
  with ThreadPoolExecutor() as executor:
727
- future_to_query = {executor.submit(self.searcher.search, q): q for q in queries}
 
 
 
728
 
729
  for future in as_completed(future_to_query):
730
  query = future_to_query[future]
@@ -737,9 +775,13 @@ class WebBrowser(BaseAction):
737
  if result['url'] not in search_results:
738
  search_results[result['url']] = result
739
  else:
740
- search_results[result['url']]['summ'] += f"\n{result['summ']}"
 
741
 
742
- self.search_results = {idx: result for idx, result in enumerate(search_results.values())}
 
 
 
743
  return self.search_results
744
 
745
  @tool_api
@@ -756,8 +798,7 @@ class WebBrowser(BaseAction):
756
  with ThreadPoolExecutor() as executor:
757
  future_to_id = {
758
  executor.submit(self.fetcher.fetch, self.search_results[select_id]['url']): select_id
759
- for select_id in select_ids
760
- if select_id in self.search_results
761
  }
762
  for future in as_completed(future_to_id):
763
  select_id = future_to_id[future]
@@ -767,8 +808,10 @@ class WebBrowser(BaseAction):
767
  warnings.warn(f'{select_id} generated an exception: {exc}')
768
  else:
769
  if web_success:
770
- self.search_results[select_id]['content'] = web_content[:8192]
771
- new_search_results[select_id] = self.search_results[select_id].copy()
 
 
772
  new_search_results[select_id].pop('summ')
773
 
774
  return new_search_results
@@ -784,12 +827,13 @@ class WebBrowser(BaseAction):
784
 
785
 
786
  class AsyncWebBrowser(AsyncActionMixin, WebBrowser):
787
- """Wrapper around the Web Browser Tool."""
 
788
 
789
  @tool_api
790
  async def search(self, query: Union[str, List[str]]) -> dict:
791
  """BING search API
792
-
793
  Args:
794
  query (List[str]): list of search query strings
795
  """
@@ -812,9 +856,13 @@ class AsyncWebBrowser(AsyncActionMixin, WebBrowser):
812
  if result['url'] not in search_results:
813
  search_results[result['url']] = result
814
  else:
815
- search_results[result['url']]['summ'] += f"\n{result['summ']}"
 
816
 
817
- self.search_results = {idx: result for idx, result in enumerate(search_results.values())}
 
 
 
818
  return self.search_results
819
 
820
  @tool_api
@@ -831,7 +879,8 @@ class AsyncWebBrowser(AsyncActionMixin, WebBrowser):
831
  tasks = []
832
  for select_id in select_ids:
833
  if select_id in self.search_results:
834
- task = asyncio.create_task(self.fetcher.afetch(self.search_results[select_id]['url']))
 
835
  task.select_id = select_id
836
  tasks.append(task)
837
  async for future in async_as_completed(tasks):
@@ -842,8 +891,10 @@ class AsyncWebBrowser(AsyncActionMixin, WebBrowser):
842
  warnings.warn(f'{select_id} generated an exception: {exc}')
843
  else:
844
  if web_success:
845
- self.search_results[select_id]['content'] = web_content[:8192]
846
- new_search_results[select_id] = self.search_results[select_id].copy()
 
 
847
  new_search_results[select_id].pop('summ')
848
  return new_search_results
849
 
 
18
  from asyncache import cached as acached
19
  from bs4 import BeautifulSoup
20
  from cachetools import TTLCache, cached
21
+ from duckduckgo_search import DDGS, AsyncDDGS
22
 
23
  from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
24
  from lagent.actions.parser import BaseParser, JsonParser
 
35
  filtered_results = {}
36
  count = 0
37
  for url, snippet, title in results:
38
+ if all(domain not in url
39
+ for domain in self.black_list) and not url.endswith('.pdf'):
40
  filtered_results[count] = {
41
  'url': url,
42
  'summ': json.dumps(snippet, ensure_ascii=False)[1:-1],
43
+ 'title': title
44
  }
45
  count += 1
46
  if count >= self.topk:
 
50
 
51
  class DuckDuckGoSearch(BaseSearch):
52
 
53
+ def __init__(self,
54
+ topk: int = 3,
55
+ black_list: List[str] = [
56
+ 'enoN',
57
+ 'youtube.com',
58
+ 'bilibili.com',
59
+ 'researchgate.net',
60
+ ],
61
+ **kwargs):
 
 
62
  self.proxy = kwargs.get('proxy')
63
  self.timeout = kwargs.get('timeout', 30)
64
  super().__init__(topk, black_list)
 
67
  def search(self, query: str, max_retry: int = 3) -> dict:
68
  for attempt in range(max_retry):
69
  try:
70
+ response = self._call_ddgs(
71
+ query, timeout=self.timeout, proxy=self.proxy)
72
  return self._parse_response(response)
73
  except Exception as e:
74
  logging.exception(str(e))
75
+ warnings.warn(
76
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
77
  time.sleep(random.randint(2, 5))
78
+ raise Exception(
79
+ 'Failed to get search results from DuckDuckGo after retries.')
80
 
81
  @acached(cache=TTLCache(maxsize=100, ttl=600))
82
  async def asearch(self, query: str, max_retry: int = 3) -> dict:
 
 
83
  for attempt in range(max_retry):
84
  try:
85
  ddgs = AsyncDDGS(timeout=self.timeout, proxy=self.proxy)
86
+ response = await ddgs.atext(query.strip("'"), max_results=10)
87
  return self._parse_response(response)
88
  except Exception as e:
89
  if isinstance(e, asyncio.TimeoutError):
90
  logging.exception('Request to DDGS timed out.')
91
  logging.exception(str(e))
92
+ warnings.warn(
93
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
94
  await asyncio.sleep(random.randint(2, 5))
95
+ raise Exception(
96
+ 'Failed to get search results from DuckDuckGo after retries.')
97
 
98
  async def _async_call_ddgs(self, query: str, **kwargs) -> dict:
 
 
99
  ddgs = DDGS(**kwargs)
100
  try:
101
  response = await asyncio.wait_for(
102
+ asyncio.to_thread(ddgs.text, query.strip("'"), max_results=10),
103
+ timeout=self.timeout)
104
  return response
105
  except asyncio.TimeoutError:
106
  logging.exception('Request to DDGS timed out.')
 
110
  loop = asyncio.new_event_loop()
111
  asyncio.set_event_loop(loop)
112
  try:
113
+ response = loop.run_until_complete(
114
+ self._async_call_ddgs(query, **kwargs))
115
  return response
116
  finally:
117
  loop.close()
118
 
119
+ def _parse_response(self, response: dict) -> dict:
120
  raw_results = []
121
  for item in response:
122
  raw_results.append(
123
+ (item['href'], item['description']
124
+ if 'description' in item else item['body'], item['title']))
125
  return self._filter_results(raw_results)
126
 
127
 
128
  class BingSearch(BaseSearch):
129
 
130
+ def __init__(self,
131
+ api_key: str,
132
+ region: str = 'zh-CN',
133
+ topk: int = 3,
134
+ black_list: List[str] = [
135
+ 'enoN',
136
+ 'youtube.com',
137
+ 'bilibili.com',
138
+ 'researchgate.net',
139
+ ],
140
+ **kwargs):
 
 
141
  self.api_key = api_key
142
  self.market = region
143
  self.proxy = kwargs.get('proxy')
 
151
  return self._parse_response(response)
152
  except Exception as e:
153
  logging.exception(str(e))
154
+ warnings.warn(
155
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
156
  time.sleep(random.randint(2, 5))
157
+ raise Exception(
158
+ 'Failed to get search results from Bing Search after retries.')
159
 
160
  @acached(cache=TTLCache(maxsize=100, ttl=600))
161
  async def asearch(self, query: str, max_retry: int = 3) -> dict:
 
165
  return self._parse_response(response)
166
  except Exception as e:
167
  logging.exception(str(e))
168
+ warnings.warn(
169
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
170
  await asyncio.sleep(random.randint(2, 5))
171
+ raise Exception(
172
+ 'Failed to get search results from Bing Search after retries.')
173
 
174
  def _call_bing_api(self, query: str) -> dict:
175
  endpoint = 'https://api.bing.microsoft.com/v7.0/search'
176
  params = {'q': query, 'mkt': self.market, 'count': f'{self.topk * 2}'}
177
  headers = {'Ocp-Apim-Subscription-Key': self.api_key}
178
+ response = requests.get(
179
+ endpoint, headers=headers, params=params, proxies=self.proxy)
180
  response.raise_for_status()
181
  return response.json()
182
 
 
186
  headers = {'Ocp-Apim-Subscription-Key': self.api_key}
187
  async with aiohttp.ClientSession(raise_for_status=True) as session:
188
  async with session.get(
189
+ endpoint,
190
+ headers=headers,
191
+ params=params,
192
+ proxy=self.proxy and
193
+ (self.proxy.get('http') or self.proxy.get('https'))) as resp:
194
  return await resp.json()
195
 
196
  def _parse_response(self, response: dict) -> dict:
197
+ webpages = {
198
+ w['id']: w
199
+ for w in response.get('webPages', {}).get('value', [])
200
+ }
201
  raw_results = []
202
 
203
+ for item in response.get('rankingResponse',
204
+ {}).get('mainline', {}).get('items', []):
205
  if item['answerType'] == 'WebPages':
206
  webpage = webpages.get(item['value']['id'])
207
  if webpage:
208
+ raw_results.append(
209
+ (webpage['url'], webpage['snippet'], webpage['name']))
210
+ elif item['answerType'] == 'News' and item['value'][
211
+ 'id'] == response.get('news', {}).get('id'):
212
  for news in response.get('news', {}).get('value', []):
213
+ raw_results.append(
214
+ (news['url'], news['description'], news['name']))
215
 
216
  return self._filter_results(raw_results)
217
 
 
230
  topk (int): The number of search results returned in response from API search results.
231
  region (str): The country code string. Specifies the country where the search results come from.
232
  language (str): The language code string. Specifies the preferred language for the search results.
233
+ extra_snippets (bool): Allows retrieving up to 5 additional snippets, which are alternative excerpts from the search results.
 
234
  **kwargs: Any other parameters related to the Brave Search API. Find more details at
235
  https://api.search.brave.com/app/documentation/web-search/get-started.
236
  """
237
 
238
+ def __init__(self,
239
+ api_key: str,
240
+ region: str = 'ALL',
241
+ language: str = 'zh-hans',
242
+ extra_snippests: bool = True,
243
+ topk: int = 3,
244
+ black_list: List[str] = [
245
+ 'enoN',
246
+ 'youtube.com',
247
+ 'bilibili.com',
248
+ 'researchgate.net',
249
+ ],
250
+ **kwargs):
 
 
251
  self.api_key = api_key
252
  self.market = region
253
  self.proxy = kwargs.get('proxy')
 
265
  return self._parse_response(response)
266
  except Exception as e:
267
  logging.exception(str(e))
268
+ warnings.warn(
269
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
270
  time.sleep(random.randint(2, 5))
271
+ raise Exception(
272
+ 'Failed to get search results from Brave Search after retries.')
273
 
274
  @acached(cache=TTLCache(maxsize=100, ttl=600))
275
  async def asearch(self, query: str, max_retry: int = 3) -> dict:
 
279
  return self._parse_response(response)
280
  except Exception as e:
281
  logging.exception(str(e))
282
+ warnings.warn(
283
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
284
  await asyncio.sleep(random.randint(2, 5))
285
+ raise Exception(
286
+ 'Failed to get search results from Brave Search after retries.')
287
 
288
  def _call_brave_api(self, query: str) -> dict:
289
  endpoint = f'https://api.search.brave.com/res/v1/{self.search_type}/search'
 
293
  'search_lang': self.language,
294
  'extra_snippets': self.extra_snippests,
295
  'count': self.topk,
296
+ **{
297
+ key: value
298
+ for key, value in self.kwargs.items() if value is not None
299
+ },
300
  }
301
+ headers = {
302
+ 'X-Subscription-Token': self.api_key or '',
303
+ 'Accept': 'application/json'
304
+ }
305
+ response = requests.get(
306
+ endpoint, headers=headers, params=params, proxies=self.proxy)
307
  response.raise_for_status()
308
  return response.json()
309
 
 
315
  'search_lang': self.language,
316
  'extra_snippets': self.extra_snippests,
317
  'count': self.topk,
318
+ **{
319
+ key: value
320
+ for key, value in self.kwargs.items() if value is not None
321
+ },
322
+ }
323
+ headers = {
324
+ 'X-Subscription-Token': self.api_key or '',
325
+ 'Accept': 'application/json'
326
  }
 
327
  async with aiohttp.ClientSession(raise_for_status=True) as session:
328
  async with session.get(
329
+ endpoint,
330
+ headers=headers,
331
+ params=params,
332
+ proxy=self.proxy and
333
+ (self.proxy.get('http') or self.proxy.get('https'))) as resp:
334
  return await resp.json()
335
 
336
  def _parse_response(self, response: dict) -> dict:
 
341
  raw_results = []
342
 
343
  for item in filtered_result:
344
+ raw_results.append((
345
+ item.get('url', ''),
346
+ ' '.join(
347
+ filter(None, [
348
+ item.get('description'),
349
+ *item.get('extra_snippets', [])
350
+ ])),
351
+ item.get('title', ''),
352
+ ))
353
  return self._filter_results(raw_results)
354
 
355
 
 
376
  'search': 'organic',
377
  }
378
 
379
+ def __init__(self,
380
+ api_key: str,
381
+ topk: int = 3,
382
+ black_list: List[str] = [
383
+ 'enoN',
384
+ 'youtube.com',
385
+ 'bilibili.com',
386
+ 'researchgate.net',
387
+ ],
388
+ **kwargs):
 
 
389
  self.api_key = api_key
390
  self.proxy = kwargs.get('proxy')
391
  self.search_type = kwargs.get('search_type', 'search')
 
400
  return self._parse_response(response)
401
  except Exception as e:
402
  logging.exception(str(e))
403
+ warnings.warn(
404
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
405
  time.sleep(random.randint(2, 5))
406
+ raise Exception(
407
+ 'Failed to get search results from Google Serper Search after retries.'
408
+ )
409
 
410
  @acached(cache=TTLCache(maxsize=100, ttl=600))
411
  async def asearch(self, query: str, max_retry: int = 3) -> dict:
 
415
  return self._parse_response(response)
416
  except Exception as e:
417
  logging.exception(str(e))
418
+ warnings.warn(
419
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
420
  await asyncio.sleep(random.randint(2, 5))
421
+ raise Exception(
422
+ 'Failed to get search results from Google Serper Search after retries.'
423
+ )
424
 
425
  def _call_serper_api(self, query: str) -> dict:
426
  endpoint = f'https://google.serper.dev/{self.search_type}'
427
  params = {
428
  'q': query,
429
  'num': self.topk,
430
+ **{
431
+ key: value
432
+ for key, value in self.kwargs.items() if value is not None
433
+ },
434
  }
435
+ headers = {
436
+ 'X-API-KEY': self.api_key or '',
437
+ 'Content-Type': 'application/json'
438
+ }
439
+ response = requests.get(
440
+ endpoint, headers=headers, params=params, proxies=self.proxy)
441
  response.raise_for_status()
442
  return response.json()
443
 
 
446
  params = {
447
  'q': query,
448
  'num': self.topk,
449
+ **{
450
+ key: value
451
+ for key, value in self.kwargs.items() if value is not None
452
+ },
453
+ }
454
+ headers = {
455
+ 'X-API-KEY': self.api_key or '',
456
+ 'Content-Type': 'application/json'
457
  }
 
458
  async with aiohttp.ClientSession(raise_for_status=True) as session:
459
  async with session.get(
460
+ endpoint,
461
+ headers=headers,
462
+ params=params,
463
+ proxy=self.proxy and
464
+ (self.proxy.get('http') or self.proxy.get('https'))) as resp:
465
  return await resp.json()
466
 
467
  def _parse_response(self, response: dict) -> dict:
 
472
  if answer_box.get('answer'):
473
  raw_results.append(('', answer_box.get('answer'), ''))
474
  elif answer_box.get('snippet'):
475
+ raw_results.append(
476
+ ('', answer_box.get('snippet').replace('\n', ' '), ''))
477
  elif answer_box.get('snippetHighlighted'):
478
+ raw_results.append(
479
+ ('', answer_box.get('snippetHighlighted'), ''))
480
 
481
  if response.get('knowledgeGraph'):
482
  kg = response.get('knowledgeGraph', {})
483
  description = kg.get('description', '')
484
+ attributes = '. '.join(
485
+ f'{attribute}: {value}'
486
+ for attribute, value in kg.get('attributes', {}).items())
487
  raw_results.append(
488
+ (kg.get('descriptionLink', ''),
489
+ f'{description}. {attributes}' if attributes else description,
490
+ f"{kg.get('title', '')}: {kg.get('type', '')}."))
491
+
492
+ for result in response[self.result_key_for_type[
493
+ self.search_type]][:self.topk]:
 
 
494
  description = result.get('snippet', '')
495
  attributes = '. '.join(
496
+ f'{attribute}: {value}'
497
+ for attribute, value in result.get('attributes', {}).items())
498
  raw_results.append(
499
+ (result.get('link', ''),
500
+ f'{description}. {attributes}' if attributes else description,
501
+ result.get('title', '')))
 
 
 
502
 
503
  return self._filter_results(raw_results)
504
 
 
529
  Supports multiple values separated by commas. Example: `30010255`.
530
  """
531
 
532
+ def __init__(self,
533
+ secret_id: str = 'Your SecretId',
534
+ secret_key: str = 'Your SecretKey',
535
+ api_key: str = '',
536
+ action: str = 'SearchCommon',
537
+ version: str = '2020-12-29',
538
+ service: str = 'tms',
539
+ host: str = 'tms.tencentcloudapi.com',
540
+ topk: int = 3,
541
+ tsn: int = None,
542
+ insite: str = None,
543
+ category: str = None,
544
+ vrid: str = None,
545
+ black_list: List[str] = [
546
+ 'enoN',
547
+ 'youtube.com',
548
+ 'bilibili.com',
549
+ 'researchgate.net',
550
+ ]):
 
 
551
  self.secret_id = secret_id
552
  self.secret_key = secret_key
553
  self.api_key = api_key
 
569
  return self._parse_response(response)
570
  except Exception as e:
571
  logging.exception(str(e))
572
+ warnings.warn(
573
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
574
  time.sleep(random.randint(2, 5))
575
+ raise Exception(
576
+ 'Failed to get search results from Bing Search after retries.')
577
 
578
  @acached(cache=TTLCache(maxsize=100, ttl=600))
579
  async def asearch(self, query: str, max_retry: int = 3) -> dict:
 
583
  return self._parse_response(response)
584
  except Exception as e:
585
  logging.exception(str(e))
586
+ warnings.warn(
587
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
588
  await asyncio.sleep(random.randint(2, 5))
589
+ raise Exception(
590
+ 'Failed to get search results from Bing Search after retries.')
591
 
592
  def _get_headers_and_payload(self, query: str) -> tuple:
593
 
 
617
  ct = 'application/json; charset=utf-8'
618
  canonical_headers = f'content-type:{ct}\nhost:{self.host}\nx-tc-action:{self.action.lower()}\n'
619
  signed_headers = 'content-type;host;x-tc-action'
620
+ hashed_request_payload = hashlib.sha256(
621
+ payload.encode('utf-8')).hexdigest()
622
  canonical_request = (
623
+ http_request_method + '\n' + canonical_uri + '\n' +
624
+ canonical_querystring + '\n' + canonical_headers + '\n' +
625
+ signed_headers + '\n' + hashed_request_payload)
 
 
 
 
 
 
 
 
 
626
 
627
  # ************* 步骤 2:拼接待签名字符串 *************
628
  credential_scope = date + '/' + self.service + '/' + 'tc3_request'
629
+ hashed_canonical_request = hashlib.sha256(
630
+ canonical_request.encode('utf-8')).hexdigest()
631
+ string_to_sign = (
632
+ algorithm + '\n' + str(timestamp) + '\n' + credential_scope +
633
+ '\n' + hashed_canonical_request)
634
 
635
  # ************* 步骤 3:计算签名 *************
636
  secret_date = sign(('TC3' + self.secret_key).encode('utf-8'), date)
637
  secret_service = sign(secret_date, self.service)
638
  secret_signing = sign(secret_service, 'tc3_request')
639
+ signature = hmac.new(secret_signing, string_to_sign.encode('utf-8'),
640
+ hashlib.sha256).hexdigest()
641
 
642
  # ************* 步骤 4:拼接 Authorization *************
643
  authorization = (
644
+ algorithm + ' ' + 'Credential=' + self.secret_id + '/' +
645
+ credential_scope + ', ' + 'SignedHeaders=' + signed_headers +
646
+ ', ' + 'Signature=' + signature)
 
 
 
 
 
 
 
 
 
 
647
 
648
  # ************* 步骤 5:构造并发起请求 *************
649
  headers = {
 
652
  'Host': self.host,
653
  'X-TC-Action': self.action,
654
  'X-TC-Timestamp': str(timestamp),
655
+ 'X-TC-Version': self.version
656
  }
657
  # if self.region:
658
  # headers["X-TC-Region"] = self.region
 
670
  except Exception as e:
671
  logging.warning(str(e))
672
  import ast
 
673
  resp = ast.literal_eval(resp)
674
  return resp.get('Response', dict())
675
 
676
  async def _async_call_tencent_api(self, query: str):
677
  headers, payload = self._get_headers_and_payload(query)
678
  async with aiohttp.ClientSession(raise_for_status=True) as session:
679
+ async with session.post(
680
+ 'https://' + self.host.lstrip('/'),
681
+ headers=headers,
682
+ data=payload) as resp:
683
  return (await resp.json()).get('Response', {})
684
 
685
  def _parse_response(self, response: dict) -> dict:
 
688
  display = json.loads(item['Display'])
689
  if not display['url']:
690
  continue
691
+ raw_results.append((display['url'], display['content']
692
+ or display['abstract_info'], display['title']))
693
  return self._filter_results(raw_results)
694
 
695
 
 
715
  async def afetch(self, url: str) -> Tuple[bool, str]:
716
  try:
717
  async with aiohttp.ClientSession(
718
+ raise_for_status=True,
719
+ timeout=aiohttp.ClientTimeout(self.timeout)) as session:
720
  async with session.get(url) as resp:
721
  html = await resp.text(errors='ignore')
722
  text = BeautifulSoup(html, 'html.parser').get_text()
 
727
 
728
 
729
  class WebBrowser(BaseAction):
730
+ """Wrapper around the Web Browser Tool.
731
+ """
732
+
733
+ def __init__(self,
734
+ searcher_type: str = 'DuckDuckGoSearch',
735
+ timeout: int = 5,
736
+ black_list: Optional[List[str]] = [
737
+ 'enoN',
738
+ 'youtube.com',
739
+ 'bilibili.com',
740
+ 'researchgate.net',
741
+ ],
742
+ topk: int = 20,
743
+ description: Optional[dict] = None,
744
+ parser: Type[BaseParser] = JsonParser,
745
+ **kwargs):
746
+ self.searcher = eval(searcher_type)(
747
+ black_list=black_list, topk=topk, **kwargs)
748
  self.fetcher = ContentFetcher(timeout=timeout)
749
  self.search_results = None
750
  super().__init__(description, parser)
 
759
  search_results = {}
760
 
761
  with ThreadPoolExecutor() as executor:
762
+ future_to_query = {
763
+ executor.submit(self.searcher.search, q): q
764
+ for q in queries
765
+ }
766
 
767
  for future in as_completed(future_to_query):
768
  query = future_to_query[future]
 
775
  if result['url'] not in search_results:
776
  search_results[result['url']] = result
777
  else:
778
+ search_results[
779
+ result['url']]['summ'] += f"\n{result['summ']}"
780
 
781
+ self.search_results = {
782
+ idx: result
783
+ for idx, result in enumerate(search_results.values())
784
+ }
785
  return self.search_results
786
 
787
  @tool_api
 
798
  with ThreadPoolExecutor() as executor:
799
  future_to_id = {
800
  executor.submit(self.fetcher.fetch, self.search_results[select_id]['url']): select_id
801
+ for select_id in select_ids if select_id in self.search_results
 
802
  }
803
  for future in as_completed(future_to_id):
804
  select_id = future_to_id[future]
 
808
  warnings.warn(f'{select_id} generated an exception: {exc}')
809
  else:
810
  if web_success:
811
+ self.search_results[select_id][
812
+ 'content'] = web_content[:8192]
813
+ new_search_results[select_id] = self.search_results[
814
+ select_id].copy()
815
  new_search_results[select_id].pop('summ')
816
 
817
  return new_search_results
 
827
 
828
 
829
  class AsyncWebBrowser(AsyncActionMixin, WebBrowser):
830
+ """Wrapper around the Web Browser Tool.
831
+ """
832
 
833
  @tool_api
834
  async def search(self, query: Union[str, List[str]]) -> dict:
835
  """BING search API
836
+
837
  Args:
838
  query (List[str]): list of search query strings
839
  """
 
856
  if result['url'] not in search_results:
857
  search_results[result['url']] = result
858
  else:
859
+ search_results[
860
+ result['url']]['summ'] += f"\n{result['summ']}"
861
 
862
+ self.search_results = {
863
+ idx: result
864
+ for idx, result in enumerate(search_results.values())
865
+ }
866
  return self.search_results
867
 
868
  @tool_api
 
879
  tasks = []
880
  for select_id in select_ids:
881
  if select_id in self.search_results:
882
+ task = asyncio.create_task(
883
+ self.fetcher.afetch(self.search_results[select_id]['url']))
884
  task.select_id = select_id
885
  tasks.append(task)
886
  async for future in async_as_completed(tasks):
 
891
  warnings.warn(f'{select_id} generated an exception: {exc}')
892
  else:
893
  if web_success:
894
+ self.search_results[select_id][
895
+ 'content'] = web_content[:8192]
896
+ new_search_results[select_id] = self.search_results[
897
+ select_id].copy()
898
  new_search_results[select_id].pop('summ')
899
  return new_search_results
900
 
lagent/agents/__init__.py CHANGED
@@ -1,33 +1,9 @@
1
- from .agent import (
2
- Agent,
3
- AgentDict,
4
- AgentList,
5
- AsyncAgent,
6
- AsyncSequential,
7
- AsyncStreamingAgent,
8
- AsyncStreamingSequential,
9
- Sequential,
10
- StreamingAgent,
11
- StreamingSequential,
12
- )
13
  from .react import AsyncReAct, ReAct
14
  from .stream import AgentForInternLM, AsyncAgentForInternLM, AsyncMathCoder, MathCoder
15
 
16
  __all__ = [
17
- 'Agent',
18
- 'AgentDict',
19
- 'AgentList',
20
- 'AsyncAgent',
21
- 'AgentForInternLM',
22
- 'AsyncAgentForInternLM',
23
- 'MathCoder',
24
- 'AsyncMathCoder',
25
- 'ReAct',
26
- 'AsyncReAct',
27
- 'Sequential',
28
- 'AsyncSequential',
29
- 'StreamingAgent',
30
- 'StreamingSequential',
31
- 'AsyncStreamingAgent',
32
- 'AsyncStreamingSequential',
33
  ]
 
1
+ from .agent import Agent, AgentDict, AgentList, AsyncAgent, AsyncSequential, Sequential
 
 
 
 
 
 
 
 
 
 
 
2
  from .react import AsyncReAct, ReAct
3
  from .stream import AgentForInternLM, AsyncAgentForInternLM, AsyncMathCoder, MathCoder
4
 
5
  __all__ = [
6
+ 'Agent', 'AgentDict', 'AgentList', 'AsyncAgent', 'AgentForInternLM',
7
+ 'AsyncAgentForInternLM', 'MathCoder', 'AsyncMathCoder', 'ReAct',
8
+ 'AsyncReAct', 'Sequential', 'AsyncSequential'
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  ]
lagent/agents/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (547 Bytes). View file
 
lagent/agents/__pycache__/agent.cpython-310.pyc ADDED
Binary file (12.9 kB). View file
 
lagent/agents/__pycache__/react.cpython-310.pyc ADDED
Binary file (4.85 kB). View file
 
lagent/agents/__pycache__/stream.cpython-310.pyc ADDED
Binary file (8.95 kB). View file
 
lagent/agents/agent.py CHANGED
@@ -3,7 +3,7 @@ import warnings
3
  from collections import OrderedDict, UserDict, UserList, abc
4
  from functools import wraps
5
  from itertools import chain, repeat
6
- from typing import Any, AsyncGenerator, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union
7
 
8
  from lagent.agents.aggregator import DefaultAggregator
9
  from lagent.hooks import Hook, RemovableHandle
@@ -11,7 +11,7 @@ from lagent.llms import BaseLLM
11
  from lagent.memory import Memory, MemoryManager
12
  from lagent.prompts.parsers import StrParser
13
  from lagent.prompts.prompt_template import PromptTemplate
14
- from lagent.schema import AgentMessage, ModelStatusCode
15
  from lagent.utils import create_object
16
 
17
 
@@ -63,17 +63,29 @@ class Agent:
63
  if self.memory:
64
  self.memory.add(message, session_id=session_id)
65
 
66
- def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessage:
 
 
 
 
 
67
  # message.receiver = self.name
68
- message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message]
 
 
 
69
  for hook in self._hooks.values():
70
  result = hook.before_agent(self, message, session_id)
71
  if result:
72
  message = result
73
  self.update_memory(message, session_id=session_id)
74
- response_message = self.forward(*message, session_id=session_id, **kwargs)
 
75
  if not isinstance(response_message, AgentMessage):
76
- response_message = AgentMessage(sender=self.name, content=response_message)
 
 
 
77
  self.update_memory(response_message, session_id=session_id)
78
  response_message = copy.deepcopy(response_message)
79
  for hook in self._hooks.values():
@@ -82,14 +94,25 @@ class Agent:
82
  response_message = result
83
  return response_message
84
 
85
- def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]:
 
 
 
86
  formatted_messages = self.aggregator.aggregate(
87
- self.memory.get(session_id), self.name, self.output_format, self.template
 
 
 
88
  )
89
  llm_response = self.llm.chat(formatted_messages, **kwargs)
90
  if self.output_format:
91
- formatted_messages = self.output_format.parse_response(llm_response)
92
- return AgentMessage(sender=self.name, content=llm_response, formatted=formatted_messages)
 
 
 
 
 
93
  return llm_response
94
 
95
  def __setattr__(self, __name: str, __value: Any) -> None:
@@ -142,8 +165,12 @@ class Agent:
142
  self._hooks[handle.id] = hook
143
  return handle
144
 
145
- def reset(self, session_id=0, keypath: Optional[str] = None, recursive: bool = False):
146
- assert not (keypath and recursive), 'keypath and recursive can\'t be used together'
 
 
 
 
147
  if keypath:
148
  keys, agent = keypath.split('.'), self
149
  for key in keys:
@@ -162,13 +189,15 @@ class Agent:
162
  def __repr__(self):
163
 
164
  def _rcsv_repr(agent, n_indent=1):
165
- res = agent.__class__.__name__ + (f"(name='{agent.name}')" if agent.name else '')
 
166
  modules = [
167
  f"{n_indent * ' '}({name}): {_rcsv_repr(agent, n_indent + 1)}"
168
  for name, agent in getattr(agent, '_agents', {}).items()
169
  ]
170
  if modules:
171
- res += '(\n' + '\n'.join(modules) + f'\n{(n_indent - 1) * " "})'
 
172
  elif not res.endswith(')'):
173
  res += '()'
174
  return res
@@ -176,18 +205,28 @@ class Agent:
176
  return _rcsv_repr(self)
177
 
178
 
179
- class AsyncAgentMixin:
180
 
181
- async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessage:
182
- message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message]
 
 
 
 
 
 
183
  for hook in self._hooks.values():
184
  result = hook.before_agent(self, message, session_id)
185
  if result:
186
  message = result
187
  self.update_memory(message, session_id=session_id)
188
- response_message = await self.forward(*message, session_id=session_id, **kwargs)
 
189
  if not isinstance(response_message, AgentMessage):
190
- response_message = AgentMessage(sender=self.name, content=response_message)
 
 
 
191
  self.update_memory(response_message, session_id=session_id)
192
  response_message = copy.deepcopy(response_message)
193
  for hook in self._hooks.values():
@@ -196,133 +235,40 @@ class AsyncAgentMixin:
196
  response_message = result
197
  return response_message
198
 
199
- async def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]:
 
 
 
200
  formatted_messages = self.aggregator.aggregate(
201
- self.memory.get(session_id), self.name, self.output_format, self.template
 
 
 
202
  )
203
- llm_response = await self.llm.chat(formatted_messages, session_id, **kwargs)
 
204
  if self.output_format:
205
- formatted_messages = self.output_format.parse_response(llm_response)
206
- return AgentMessage(sender=self.name, content=llm_response, formatted=formatted_messages)
207
- return llm_response
208
-
209
-
210
- class AsyncAgent(AsyncAgentMixin, Agent):
211
- """Asynchronous variant of the Agent class"""
212
-
213
- pass
214
-
215
-
216
- class StreamingAgentMixin:
217
- """Component that makes agent calling output a streaming response."""
218
-
219
- def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> Generator[AgentMessage, None, None]:
220
- message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message]
221
- for hook in self._hooks.values():
222
- result = hook.before_agent(self, message, session_id)
223
- if result:
224
- message = result
225
- self.update_memory(message, session_id=session_id)
226
- response_message = AgentMessage(sender=self.name, content="")
227
- for response_message in self.forward(*message, session_id=session_id, **kwargs):
228
- if not isinstance(response_message, AgentMessage):
229
- model_state, response = response_message
230
- response_message = AgentMessage(sender=self.name, content=response, stream_state=model_state)
231
- yield response_message.model_copy()
232
- self.update_memory(response_message, session_id=session_id)
233
- response_message = copy.deepcopy(response_message)
234
- for hook in self._hooks.values():
235
- result = hook.after_agent(self, response_message, session_id)
236
- if result:
237
- response_message = result
238
- yield response_message
239
-
240
- def forward(
241
- self, *message: AgentMessage, session_id=0, **kwargs
242
- ) -> Generator[Union[AgentMessage, Tuple[ModelStatusCode, str]], None, None]:
243
- formatted_messages = self.aggregator.aggregate(
244
- self.memory.get(session_id), self.name, self.output_format, self.template
245
- )
246
- for model_state, response, *_ in self.llm.stream_chat(formatted_messages, session_id=session_id, **kwargs):
247
- yield (
248
- AgentMessage(
249
- sender=self.name,
250
- content=response,
251
- formatted=self.output_format.parse_response(response),
252
- stream_state=model_state,
253
- )
254
- if self.output_format
255
- else (model_state, response)
256
- )
257
-
258
-
259
- class AsyncStreamingAgentMixin:
260
- """Component that makes asynchronous agent calling output a streaming response."""
261
-
262
- async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AsyncGenerator[AgentMessage, None]:
263
- message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message]
264
- for hook in self._hooks.values():
265
- result = hook.before_agent(self, message, session_id)
266
- if result:
267
- message = result
268
- self.update_memory(message, session_id=session_id)
269
- response_message = AgentMessage(sender=self.name, content="")
270
- async for response_message in self.forward(*message, session_id=session_id, **kwargs):
271
- if not isinstance(response_message, AgentMessage):
272
- model_state, response = response_message
273
- response_message = AgentMessage(sender=self.name, content=response, stream_state=model_state)
274
- yield response_message.model_copy()
275
- self.update_memory(response_message, session_id=session_id)
276
- response_message = copy.deepcopy(response_message)
277
- for hook in self._hooks.values():
278
- result = hook.after_agent(self, response_message, session_id)
279
- if result:
280
- response_message = result
281
- yield response_message
282
-
283
- async def forward(
284
- self, *message: AgentMessage, session_id=0, **kwargs
285
- ) -> AsyncGenerator[Union[AgentMessage, Tuple[ModelStatusCode, str]], None]:
286
- formatted_messages = self.aggregator.aggregate(
287
- self.memory.get(session_id), self.name, self.output_format, self.template
288
- )
289
- async for model_state, response, *_ in self.llm.stream_chat(
290
- formatted_messages, session_id=session_id, **kwargs
291
- ):
292
- yield (
293
- AgentMessage(
294
- sender=self.name,
295
- content=response,
296
- formatted=self.output_format.parse_response(response),
297
- stream_state=model_state,
298
- )
299
- if self.output_format
300
- else (model_state, response)
301
  )
302
-
303
-
304
- class StreamingAgent(StreamingAgentMixin, Agent):
305
- """Streaming variant of the Agent class"""
306
-
307
- pass
308
-
309
-
310
- class AsyncStreamingAgent(AsyncStreamingAgentMixin, Agent):
311
- """Streaming variant of the AsyncAgent class"""
312
-
313
- pass
314
 
315
 
316
  class Sequential(Agent):
317
- """Sequential is an agent container that forwards messages to each agent
318
  in the order they are added."""
319
 
320
- def __init__(self, *agents: Union[Agent, Iterable], **kwargs):
321
  super().__init__(**kwargs)
322
  self._agents = OrderedDict()
323
  if not agents:
324
  raise ValueError('At least one agent should be provided')
325
- if isinstance(agents[0], Iterable) and not isinstance(agents[0], Agent):
 
326
  if not agents[0]:
327
  raise ValueError('At least one agent should be provided')
328
  agents = agents[0]
@@ -333,11 +279,17 @@ class Sequential(Agent):
333
  key, agent = agent
334
  self.add_agent(key, agent)
335
 
336
- def add_agent(self, name: str, agent: Agent):
337
- assert isinstance(agent, Agent), f'{type(agent)} is not an Agent subclass'
 
 
338
  self._agents[str(name)] = agent
339
 
340
- def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs) -> AgentMessage:
 
 
 
 
341
  assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
342
  if exit_at is None:
343
  exit_at = len(self) - 1
@@ -345,7 +297,7 @@ class Sequential(Agent):
345
  for _ in range(exit_at + 1):
346
  agent = next(iterator)
347
  if isinstance(message, AgentMessage):
348
- message = (message,)
349
  message = agent(*message, session_id=session_id, **kwargs)
350
  return message
351
 
@@ -359,11 +311,13 @@ class Sequential(Agent):
359
  return len(self._agents)
360
 
361
 
362
- class AsyncSequential(AsyncAgentMixin, Sequential):
363
 
364
- async def forward(
365
- self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs
366
- ) -> AgentMessage:
 
 
367
  assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
368
  if exit_at is None:
369
  exit_at = len(self) - 1
@@ -371,43 +325,11 @@ class AsyncSequential(AsyncAgentMixin, Sequential):
371
  for _ in range(exit_at + 1):
372
  agent = next(iterator)
373
  if isinstance(message, AgentMessage):
374
- message = (message,)
375
  message = await agent(*message, session_id=session_id, **kwargs)
376
  return message
377
 
378
 
379
- class StreamingSequential(StreamingAgentMixin, Sequential):
380
- """Streaming variant of the Sequential class"""
381
-
382
- def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs):
383
- assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
384
- if exit_at is None:
385
- exit_at = len(self) - 1
386
- iterator = chain.from_iterable(repeat(self._agents.values()))
387
- for _ in range(exit_at + 1):
388
- agent = next(iterator)
389
- if isinstance(message, AgentMessage):
390
- message = (message,)
391
- for message in agent(*message, session_id=session_id, **kwargs):
392
- yield message
393
-
394
-
395
- class AsyncStreamingSequential(AsyncStreamingAgentMixin, Sequential):
396
- """Streaming variant of the AsyncSequential class"""
397
-
398
- async def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs):
399
- assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
400
- if exit_at is None:
401
- exit_at = len(self) - 1
402
- iterator = chain.from_iterable(repeat(self._agents.values()))
403
- for _ in range(exit_at + 1):
404
- agent = next(iterator)
405
- if isinstance(message, AgentMessage):
406
- message = (message,)
407
- async for message in agent(*message, session_id=session_id, **kwargs):
408
- yield message
409
-
410
-
411
  class AgentContainerMixin:
412
 
413
  def __init_subclass__(cls):
@@ -427,28 +349,33 @@ class AgentContainerMixin:
427
 
428
  ret = func(self, *args, **kwargs)
429
  agents = OrderedDict()
430
- for k, item in self.data.items() if isinstance(self.data, abc.Mapping) else enumerate(self.data):
431
- if isinstance(self.data, abc.Mapping) and not isinstance(k, str):
 
 
432
  _backup(data)
433
- raise KeyError(f'agent name should be a string, got {type(k)}')
 
434
  if isinstance(k, str) and '.' in k:
435
  _backup(data)
436
- raise KeyError(f'agent name can\'t contain ".", got {k}')
437
- if not isinstance(item, Agent):
 
438
  _backup(data)
439
- raise TypeError(f'{type(item)} is not an Agent subclass')
 
 
440
  agents[str(k)] = item
441
  self._agents = agents
442
  return ret
443
 
444
  return wrapped_func
445
 
446
- # fmt: off
447
  for method in [
448
- 'append', 'sort', 'reverse', 'pop', 'clear', 'update',
449
- 'insert', 'extend', 'remove', '__init__', '__setitem__',
450
- '__delitem__', '__add__', '__iadd__', '__radd__', '__mul__',
451
- '__imul__', '__rmul__'
452
  ]:
453
  if hasattr(cls, method):
454
  setattr(cls, method, wrap_api(getattr(cls, method)))
@@ -456,7 +383,8 @@ class AgentContainerMixin:
456
 
457
  class AgentList(Agent, UserList, AgentContainerMixin):
458
 
459
- def __init__(self, agents: Optional[Iterable[Agent]] = None):
 
460
  Agent.__init__(self, memory=None)
461
  UserList.__init__(self, agents)
462
  self.name = None
@@ -464,7 +392,9 @@ class AgentList(Agent, UserList, AgentContainerMixin):
464
 
465
  class AgentDict(Agent, UserDict, AgentContainerMixin):
466
 
467
- def __init__(self, agents: Optional[Mapping[str, Agent]] = None):
 
 
468
  Agent.__init__(self, memory=None)
469
  UserDict.__init__(self, agents)
470
  self.name = None
 
3
  from collections import OrderedDict, UserDict, UserList, abc
4
  from functools import wraps
5
  from itertools import chain, repeat
6
+ from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
7
 
8
  from lagent.agents.aggregator import DefaultAggregator
9
  from lagent.hooks import Hook, RemovableHandle
 
11
  from lagent.memory import Memory, MemoryManager
12
  from lagent.prompts.parsers import StrParser
13
  from lagent.prompts.prompt_template import PromptTemplate
14
+ from lagent.schema import AgentMessage
15
  from lagent.utils import create_object
16
 
17
 
 
63
  if self.memory:
64
  self.memory.add(message, session_id=session_id)
65
 
66
+ def __call__(
67
+ self,
68
+ *message: Union[str, AgentMessage, List[AgentMessage]],
69
+ session_id=0,
70
+ **kwargs,
71
+ ) -> AgentMessage:
72
  # message.receiver = self.name
73
+ message = [
74
+ AgentMessage(sender='user', content=m)
75
+ if isinstance(m, str) else copy.deepcopy(m) for m in message
76
+ ]
77
  for hook in self._hooks.values():
78
  result = hook.before_agent(self, message, session_id)
79
  if result:
80
  message = result
81
  self.update_memory(message, session_id=session_id)
82
+ response_message = self.forward(
83
+ *message, session_id=session_id, **kwargs)
84
  if not isinstance(response_message, AgentMessage):
85
+ response_message = AgentMessage(
86
+ sender=self.name,
87
+ content=response_message,
88
+ )
89
  self.update_memory(response_message, session_id=session_id)
90
  response_message = copy.deepcopy(response_message)
91
  for hook in self._hooks.values():
 
94
  response_message = result
95
  return response_message
96
 
97
+ def forward(self,
98
+ *message: AgentMessage,
99
+ session_id=0,
100
+ **kwargs) -> Union[AgentMessage, str]:
101
  formatted_messages = self.aggregator.aggregate(
102
+ self.memory.get(session_id),
103
+ self.name,
104
+ self.output_format,
105
+ self.template,
106
  )
107
  llm_response = self.llm.chat(formatted_messages, **kwargs)
108
  if self.output_format:
109
+ formatted_messages = self.output_format.parse_response(
110
+ llm_response)
111
+ return AgentMessage(
112
+ sender=self.name,
113
+ content=llm_response,
114
+ formatted=formatted_messages,
115
+ )
116
  return llm_response
117
 
118
  def __setattr__(self, __name: str, __value: Any) -> None:
 
165
  self._hooks[handle.id] = hook
166
  return handle
167
 
168
+ def reset(self,
169
+ session_id=0,
170
+ keypath: Optional[str] = None,
171
+ recursive: bool = False):
172
+ assert not (keypath and
173
+ recursive), 'keypath and recursive can\'t be used together'
174
  if keypath:
175
  keys, agent = keypath.split('.'), self
176
  for key in keys:
 
189
  def __repr__(self):
190
 
191
  def _rcsv_repr(agent, n_indent=1):
192
+ res = agent.__class__.__name__ + (f"(name='{agent.name}')"
193
+ if agent.name else '')
194
  modules = [
195
  f"{n_indent * ' '}({name}): {_rcsv_repr(agent, n_indent + 1)}"
196
  for name, agent in getattr(agent, '_agents', {}).items()
197
  ]
198
  if modules:
199
+ res += '(\n' + '\n'.join(
200
+ modules) + f'\n{(n_indent - 1) * " "})'
201
  elif not res.endswith(')'):
202
  res += '()'
203
  return res
 
205
  return _rcsv_repr(self)
206
 
207
 
208
+ class AsyncAgent(Agent):
209
 
210
+ async def __call__(self,
211
+ *message: AgentMessage | List[AgentMessage],
212
+ session_id=0,
213
+ **kwargs) -> AgentMessage:
214
+ message = [
215
+ AgentMessage(sender='user', content=m)
216
+ if isinstance(m, str) else copy.deepcopy(m) for m in message
217
+ ]
218
  for hook in self._hooks.values():
219
  result = hook.before_agent(self, message, session_id)
220
  if result:
221
  message = result
222
  self.update_memory(message, session_id=session_id)
223
+ response_message = await self.forward(
224
+ *message, session_id=session_id, **kwargs)
225
  if not isinstance(response_message, AgentMessage):
226
+ response_message = AgentMessage(
227
+ sender=self.name,
228
+ content=response_message,
229
+ )
230
  self.update_memory(response_message, session_id=session_id)
231
  response_message = copy.deepcopy(response_message)
232
  for hook in self._hooks.values():
 
235
  response_message = result
236
  return response_message
237
 
238
+ async def forward(self,
239
+ *message: AgentMessage,
240
+ session_id=0,
241
+ **kwargs) -> Union[AgentMessage, str]:
242
  formatted_messages = self.aggregator.aggregate(
243
+ self.memory.get(session_id),
244
+ self.name,
245
+ self.output_format,
246
+ self.template,
247
  )
248
+ llm_response = await self.llm.chat(formatted_messages, session_id,
249
+ **kwargs)
250
  if self.output_format:
251
+ formatted_messages = self.output_format.parse_response(
252
+ llm_response)
253
+ return AgentMessage(
254
+ sender=self.name,
255
+ content=llm_response,
256
+ formatted=formatted_messages,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  )
258
+ return llm_response
 
 
 
 
 
 
 
 
 
 
 
259
 
260
 
261
  class Sequential(Agent):
262
+ """Sequential is an agent container that forwards messages to each agent
263
  in the order they are added."""
264
 
265
+ def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs):
266
  super().__init__(**kwargs)
267
  self._agents = OrderedDict()
268
  if not agents:
269
  raise ValueError('At least one agent should be provided')
270
+ if isinstance(agents[0],
271
+ Iterable) and not isinstance(agents[0], Agent):
272
  if not agents[0]:
273
  raise ValueError('At least one agent should be provided')
274
  agents = agents[0]
 
279
  key, agent = agent
280
  self.add_agent(key, agent)
281
 
282
+ def add_agent(self, name: str, agent: Union[Agent, AsyncAgent]):
283
+ assert isinstance(
284
+ agent, (Agent, AsyncAgent
285
+ )), f'{type(agent)} is not an Agent or AsyncAgent subclass'
286
  self._agents[str(name)] = agent
287
 
288
+ def forward(self,
289
+ *message: AgentMessage,
290
+ session_id=0,
291
+ exit_at: Optional[int] = None,
292
+ **kwargs) -> AgentMessage:
293
  assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
294
  if exit_at is None:
295
  exit_at = len(self) - 1
 
297
  for _ in range(exit_at + 1):
298
  agent = next(iterator)
299
  if isinstance(message, AgentMessage):
300
+ message = (message, )
301
  message = agent(*message, session_id=session_id, **kwargs)
302
  return message
303
 
 
311
  return len(self._agents)
312
 
313
 
314
+ class AsyncSequential(Sequential, AsyncAgent):
315
 
316
+ async def forward(self,
317
+ *message: AgentMessage,
318
+ session_id=0,
319
+ exit_at: Optional[int] = None,
320
+ **kwargs) -> AgentMessage:
321
  assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
322
  if exit_at is None:
323
  exit_at = len(self) - 1
 
325
  for _ in range(exit_at + 1):
326
  agent = next(iterator)
327
  if isinstance(message, AgentMessage):
328
+ message = (message, )
329
  message = await agent(*message, session_id=session_id, **kwargs)
330
  return message
331
 
332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  class AgentContainerMixin:
334
 
335
  def __init_subclass__(cls):
 
349
 
350
  ret = func(self, *args, **kwargs)
351
  agents = OrderedDict()
352
+ for k, item in (self.data.items() if isinstance(
353
+ self.data, abc.Mapping) else enumerate(self.data)):
354
+ if isinstance(self.data,
355
+ abc.Mapping) and not isinstance(k, str):
356
  _backup(data)
357
+ raise KeyError(
358
+ f'agent name should be a string, got {type(k)}')
359
  if isinstance(k, str) and '.' in k:
360
  _backup(data)
361
+ raise KeyError(
362
+ f'agent name can\'t contain ".", got {k}')
363
+ if not isinstance(item, (Agent, AsyncAgent)):
364
  _backup(data)
365
+ raise TypeError(
366
+ f'{type(item)} is not an Agent or AsyncAgent subclass'
367
+ )
368
  agents[str(k)] = item
369
  self._agents = agents
370
  return ret
371
 
372
  return wrapped_func
373
 
 
374
  for method in [
375
+ 'append', 'sort', 'reverse', 'pop', 'clear', 'update',
376
+ 'insert', 'extend', 'remove', '__init__', '__setitem__',
377
+ '__delitem__', '__add__', '__iadd__', '__radd__', '__mul__',
378
+ '__imul__', '__rmul__'
379
  ]:
380
  if hasattr(cls, method):
381
  setattr(cls, method, wrap_api(getattr(cls, method)))
 
383
 
384
  class AgentList(Agent, UserList, AgentContainerMixin):
385
 
386
+ def __init__(self,
387
+ agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None):
388
  Agent.__init__(self, memory=None)
389
  UserList.__init__(self, agents)
390
  self.name = None
 
392
 
393
  class AgentDict(Agent, UserDict, AgentContainerMixin):
394
 
395
+ def __init__(self,
396
+ agents: Optional[Mapping[str, Union[Agent,
397
+ AsyncAgent]]] = None):
398
  Agent.__init__(self, memory=None)
399
  UserDict.__init__(self, agents)
400
  self.name = None
lagent/agents/aggregator/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (305 Bytes). View file
 
lagent/agents/aggregator/__pycache__/default_aggregator.cpython-310.pyc ADDED
Binary file (1.6 kB). View file
 
lagent/agents/aggregator/__pycache__/tool_aggregator.cpython-310.pyc ADDED
Binary file (2.71 kB). View file
 
lagent/agents/react.py CHANGED
@@ -12,6 +12,7 @@ from lagent.memory import Memory
12
  from lagent.prompts.parsers.json_parser import JSONParser
13
  from lagent.prompts.prompt_template import PromptTemplate
14
  from lagent.schema import AgentMessage
 
15
 
16
  select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括:
17
  {action_info}
@@ -27,88 +28,96 @@ output_format_template = """如果使用工具请遵循以下格式回复:
27
 
28
  class ReAct(Agent):
29
 
30
- def __init__(
31
- self,
32
- llm: Union[BaseLLM, Dict],
33
- actions: Union[BaseAction, List[BaseAction]],
34
- template: Union[PromptTemplate, str] = None,
35
- memory: Dict = dict(type=Memory),
36
- output_format: Dict = dict(type=JSONParser),
37
- aggregator: Dict = dict(type=DefaultAggregator),
38
- hooks: List = [dict(type=ActionPreprocessor)],
39
- finish_condition: Callable[[AgentMessage], bool] = lambda m: 'conclusion' in m.content
40
- or 'conclusion' in m.formatted,
41
- max_turn: int = 5,
42
- **kwargs
43
- ):
44
  self.max_turn = max_turn
45
  self.finish_condition = finish_condition
46
- self.actions = ActionExecutor(actions=actions, hooks=hooks)
47
- self.select_agent = Agent(
 
 
 
 
 
 
48
  llm=llm,
49
  template=template.format(
50
- action_info=json.dumps(self.actions.description()), output_format=output_format.format_instruction()
51
- ),
52
  output_format=output_format,
53
  memory=memory,
54
  aggregator=aggregator,
55
  hooks=hooks,
56
  )
 
57
  super().__init__(**kwargs)
58
 
59
- def forward(self, message: AgentMessage, session_id=0, **kwargs) -> AgentMessage:
60
  for _ in range(self.max_turn):
61
- message = self.select_agent(message, session_id=session_id, **kwargs)
62
  if self.finish_condition(message):
63
  return message
64
- message = self.actions(message, session_id=session_id)
65
  return message
66
 
67
 
68
  class AsyncReAct(AsyncAgent):
69
 
70
- def __init__(
71
- self,
72
- llm: Union[BaseLLM, Dict],
73
- actions: Union[BaseAction, List[BaseAction]],
74
- template: Union[PromptTemplate, str] = None,
75
- memory: Dict = dict(type=Memory),
76
- output_format: Dict = dict(type=JSONParser),
77
- aggregator: Dict = dict(type=DefaultAggregator),
78
- hooks: List = [dict(type=ActionPreprocessor)],
79
- finish_condition: Callable[[AgentMessage], bool] = lambda m: 'conclusion' in m.content
80
- or 'conclusion' in m.formatted,
81
- max_turn: int = 5,
82
- **kwargs
83
- ):
84
  self.max_turn = max_turn
85
  self.finish_condition = finish_condition
86
- self.actions = AsyncActionExecutor(actions=actions, hooks=hooks)
87
- self.select_agent = AsyncAgent(
 
 
 
 
 
 
88
  llm=llm,
89
  template=template.format(
90
- action_info=json.dumps(self.actions.description()), output_format=output_format.format_instruction()
91
- ),
92
  output_format=output_format,
93
  memory=memory,
94
  aggregator=aggregator,
95
  hooks=hooks,
96
  )
 
97
  super().__init__(**kwargs)
98
 
99
- async def forward(self, message: AgentMessage, session_id=0, **kwargs) -> AgentMessage:
100
  for _ in range(self.max_turn):
101
- message = await self.select_agent(message, session_id=session_id, **kwargs)
102
  if self.finish_condition(message):
103
  return message
104
- message = await self.actions(message, session_id=session_id)
105
  return message
106
 
107
 
108
  if __name__ == '__main__':
109
- import asyncio
110
-
111
- from lagent.llms import GPTAPI, AsyncGPTAPI
112
 
113
  class ActionCall(BaseModel):
114
  name: str = Field(description='调用的函数名称')
@@ -116,49 +125,37 @@ if __name__ == '__main__':
116
 
117
  class ActionFormat(BaseModel):
118
  thought_process: str = Field(
119
- description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。'
120
- )
121
  action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。')
122
 
123
  class FinishFormat(BaseModel):
124
  thought_process: str = Field(
125
- description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。'
126
- )
127
  conclusion: str = Field(description='总结当前的搜索结果,回答问题。')
128
 
129
  prompt_template = PromptTemplate(select_action_template)
130
- output_format = JSONParser(output_format_template, function_format=ActionFormat, finish_format=FinishFormat)
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  agent = ReAct(
133
- llm=dict(
134
- type=GPTAPI,
135
- model_type='gpt-4o-2024-05-13',
136
- max_new_tokens=4096,
137
- proxies=dict(),
138
- retry=1000,
139
- ),
140
  template=prompt_template,
141
  output_format=output_format,
142
- aggregator=dict(type='lagent.agents.aggregator.DefaultAggregator'),
143
- actions=[dict(type='lagent.actions.PythonInterpreter')],
144
  )
145
- response = agent(AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5'))
 
146
  print(response)
147
  response = agent(AgentMessage(sender='user', content=' 2 ** 5 呢'))
148
  print(response)
149
-
150
- async_agent = AsyncReAct(
151
- llm=dict(
152
- type=AsyncGPTAPI,
153
- model_type='gpt-4o-2024-05-13',
154
- max_new_tokens=4096,
155
- proxies=dict(),
156
- retry=1000,
157
- ),
158
- template=prompt_template,
159
- output_format=output_format,
160
- aggregator=dict(type='lagent.agents.aggregator.DefaultAggregator'),
161
- actions=[dict(type='lagent.actions.AsyncPythonInterpreter')],
162
- )
163
- response = asyncio.run(async_agent(AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5')))
164
- print(async_agent.state_dict())
 
12
  from lagent.prompts.parsers.json_parser import JSONParser
13
  from lagent.prompts.prompt_template import PromptTemplate
14
  from lagent.schema import AgentMessage
15
+ from lagent.utils import create_object
16
 
17
  select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括:
18
  {action_info}
 
28
 
29
  class ReAct(Agent):
30
 
31
+ def __init__(self,
32
+ llm: Union[BaseLLM, Dict],
33
+ actions: Union[BaseAction, List[BaseAction]],
34
+ template: Union[PromptTemplate, str] = None,
35
+ memory: Dict = dict(type=Memory),
36
+ output_format: Dict = dict(type=JSONParser),
37
+ aggregator: Dict = dict(type=DefaultAggregator),
38
+ hooks: List = [dict(type=ActionPreprocessor)],
39
+ finish_condition: Callable[[AgentMessage], bool] = lambda m:
40
+ 'conclusion' in m.content or 'conclusion' in m.formatted,
41
+ max_turn: int = 5,
42
+ **kwargs):
 
 
43
  self.max_turn = max_turn
44
  self.finish_condition = finish_condition
45
+ actions = dict(
46
+ type=ActionExecutor,
47
+ actions=actions,
48
+ hooks=hooks,
49
+ )
50
+ self.actions: ActionExecutor = create_object(actions)
51
+ select_agent = dict(
52
+ type=Agent,
53
  llm=llm,
54
  template=template.format(
55
+ action_info=json.dumps(self.actions.description()),
56
+ output_format=output_format.format_instruction()),
57
  output_format=output_format,
58
  memory=memory,
59
  aggregator=aggregator,
60
  hooks=hooks,
61
  )
62
+ self.select_agent = create_object(select_agent)
63
  super().__init__(**kwargs)
64
 
65
+ def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
66
  for _ in range(self.max_turn):
67
+ message = self.select_agent(message)
68
  if self.finish_condition(message):
69
  return message
70
+ message = self.actions(message)
71
  return message
72
 
73
 
74
  class AsyncReAct(AsyncAgent):
75
 
76
+ def __init__(self,
77
+ llm: Union[BaseLLM, Dict],
78
+ actions: Union[BaseAction, List[BaseAction]],
79
+ template: Union[PromptTemplate, str] = None,
80
+ memory: Dict = dict(type=Memory),
81
+ output_format: Dict = dict(type=JSONParser),
82
+ aggregator: Dict = dict(type=DefaultAggregator),
83
+ hooks: List = [dict(type=ActionPreprocessor)],
84
+ finish_condition: Callable[[AgentMessage], bool] = lambda m:
85
+ 'conclusion' in m.content or 'conclusion' in m.formatted,
86
+ max_turn: int = 5,
87
+ **kwargs):
 
 
88
  self.max_turn = max_turn
89
  self.finish_condition = finish_condition
90
+ actions = dict(
91
+ type=AsyncActionExecutor,
92
+ actions=actions,
93
+ hooks=hooks,
94
+ )
95
+ self.actions: AsyncActionExecutor = create_object(actions)
96
+ select_agent = dict(
97
+ type=AsyncAgent,
98
  llm=llm,
99
  template=template.format(
100
+ action_info=json.dumps(self.actions.description()),
101
+ output_format=output_format.format_instruction()),
102
  output_format=output_format,
103
  memory=memory,
104
  aggregator=aggregator,
105
  hooks=hooks,
106
  )
107
+ self.select_agent = create_object(select_agent)
108
  super().__init__(**kwargs)
109
 
110
+ async def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
111
  for _ in range(self.max_turn):
112
+ message = await self.select_agent(message)
113
  if self.finish_condition(message):
114
  return message
115
+ message = await self.actions(message)
116
  return message
117
 
118
 
119
  if __name__ == '__main__':
120
+ from lagent.llms import GPTAPI
 
 
121
 
122
  class ActionCall(BaseModel):
123
  name: str = Field(description='调用的函数名称')
 
125
 
126
  class ActionFormat(BaseModel):
127
  thought_process: str = Field(
128
+ description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。')
 
129
  action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。')
130
 
131
  class FinishFormat(BaseModel):
132
  thought_process: str = Field(
133
+ description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。')
 
134
  conclusion: str = Field(description='总结当前的搜索结果,回答问题。')
135
 
136
  prompt_template = PromptTemplate(select_action_template)
137
+ output_format = JSONParser(
138
+ output_format_template,
139
+ function_format=ActionFormat,
140
+ finish_format=FinishFormat)
141
+
142
+ llm = dict(
143
+ type=GPTAPI,
144
+ model_type='gpt-4o-2024-05-13',
145
+ key=None,
146
+ max_new_tokens=4096,
147
+ proxies=dict(),
148
+ retry=1000)
149
 
150
  agent = ReAct(
151
+ llm=llm,
 
 
 
 
 
 
152
  template=prompt_template,
153
  output_format=output_format,
154
+ aggregator=dict(type='DefaultAggregator'),
155
+ actions=[dict(type='PythonInterpreter')],
156
  )
157
+ response = agent(
158
+ AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5'))
159
  print(response)
160
  response = agent(AgentMessage(sender='user', content=' 2 ** 5 呢'))
161
  print(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lagent/agents/stream.py CHANGED
@@ -15,30 +15,25 @@ from lagent.utils import create_object
15
 
16
  API_PREFIX = (
17
  "This is the subfunction for tool '{tool_name}', you can use this tool. "
18
- 'The description of this function is: \n{description}'
19
- )
20
 
21
- META_CN = '当开启工具以及代码时,根据需求选择合适的工具进行调用'
22
 
23
- INTERPRETER_CN = (
24
- '你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。'
25
- '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。'
26
- '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),'
27
- '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),'
28
- '文本处理和分析(比如文本解析和自然语言处理),'
29
- '机器学习和数据科学(用于展示模型训练和数据可视化),'
30
- '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。'
31
- )
32
 
33
- PLUGIN_CN = (
34
- '你可以使用如下工具:'
35
- '\n{prompt}\n'
36
- '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! '
37
- '同时注意你可以使用的工具,不要随意捏造!'
38
- )
39
 
40
 
41
- def get_plugin_prompt(actions, api_desc_template='{description}'):
42
  plugin_descriptions = []
43
  for action in actions if isinstance(actions, list) else [actions]:
44
  action = create_object(action)
@@ -46,9 +41,20 @@ def get_plugin_prompt(actions, api_desc_template='{description}'):
46
  if action.is_toolkit:
47
  for api in action_desc['api_list']:
48
  api['name'] = f"{action.name}.{api['name']}"
49
- api['description'] = api_desc_template.format(tool_name=action.name, description=api['description'])
 
 
 
 
 
50
  plugin_descriptions.append(api)
51
  else:
 
 
 
 
 
 
52
  plugin_descriptions.append(action_desc)
53
  return json.dumps(plugin_descriptions, ensure_ascii=False, indent=4)
54
 
@@ -70,15 +76,17 @@ class AgentForInternLM(Agent):
70
  parsers=[
71
  dict(type=PluginParser, template=PLUGIN_CN),
72
  dict(type=InterpreterParser, template=INTERPRETER_CN),
73
- ],
74
- ),
75
  aggregator: Dict = dict(type=InternLMToolAggregator),
76
  action_hooks: List = [dict(type=InternLMActionProcessor)],
77
- finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
 
 
78
  max_turn: int = 4,
79
  **kwargs,
80
  ):
81
- self.agent = self._INTERNAL_AGENT_CLS(
 
82
  llm=llm,
83
  template=template,
84
  output_format=output_format,
@@ -86,18 +94,22 @@ class AgentForInternLM(Agent):
86
  aggregator=aggregator,
87
  hooks=kwargs.pop('hooks', None),
88
  )
89
- self.plugin_executor = plugins and ActionExecutor(plugins, hooks=action_hooks)
90
- self.interpreter_executor = interpreter and ActionExecutor(interpreter, hooks=action_hooks)
 
 
 
91
  if not (self.plugin_executor or self.interpreter_executor):
92
  warnings.warn(
93
  'Neither plugin nor interpreter executor is initialized. '
94
- 'An exception will be thrown when the agent call a tool.'
95
- )
96
  self.finish_condition = finish_condition
97
  self.max_turn = max_turn
98
  super().__init__(**kwargs)
99
 
100
  def forward(self, message: AgentMessage, session_id=0, **kwargs):
 
 
101
  for _ in range(self.max_turn):
102
  message = self.agent(message, session_id=session_id, **kwargs)
103
  assert isinstance(message.formatted, dict)
@@ -115,10 +127,15 @@ class AgentForInternLM(Agent):
115
  steps, tool_type = [], None
116
  for msg in self.agent.memory.get_memory(session_id):
117
  if msg.sender == self.agent.name:
118
- steps.append(dict(role='thought', content=msg.formatted['thought']))
 
119
  if msg.formatted['tool_type']:
120
  tool_type = msg.formatted['tool_type']
121
- steps.append(dict(role='tool', content=msg.formatted['action'], name=tool_type))
 
 
 
 
122
  elif msg.sender != 'user':
123
  feedback = dict(role='environment', content=msg.content)
124
  if tool_type:
@@ -132,22 +149,23 @@ class MathCoder(AgentForInternLM):
132
  def __init__(
133
  self,
134
  llm: Union[BaseLLM, Dict],
135
- interpreter: dict = dict(type=IPythonInteractive, timeout=20, max_out_len=8192),
 
136
  template: Union[str, dict, List[dict]] = None,
137
  memory: Dict = dict(type=Memory),
138
  output_format: Dict = dict(
139
  type=InterpreterParser,
140
- template=(
141
- 'Integrate step-by-step reasoning and Python code to solve math problems '
142
- 'using the following guidelines:\n'
143
- '- Analyze the question and write jupyter code to solve the problem;\n'
144
- r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
145
- 'units. \n'
146
- ),
147
- ),
148
  aggregator: Dict = dict(type=InternLMToolAggregator),
149
  action_hooks: List = [dict(type=InternLMActionProcessor)],
150
- finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
 
 
151
  max_turn: int = 6,
152
  **kwargs,
153
  ):
@@ -162,8 +180,7 @@ class MathCoder(AgentForInternLM):
162
  action_hooks=action_hooks,
163
  finish_condition=finish_condition,
164
  max_turn=max_turn,
165
- **kwargs,
166
- )
167
 
168
 
169
  class AsyncAgentForInternLM(AsyncAgent):
@@ -183,15 +200,17 @@ class AsyncAgentForInternLM(AsyncAgent):
183
  parsers=[
184
  dict(type=PluginParser, template=PLUGIN_CN),
185
  dict(type=InterpreterParser, template=INTERPRETER_CN),
186
- ],
187
- ),
188
  aggregator: Dict = dict(type=InternLMToolAggregator),
189
  action_hooks: List = [dict(type=InternLMActionProcessor)],
190
- finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
 
 
191
  max_turn: int = 4,
192
  **kwargs,
193
  ):
194
- self.agent = self._INTERNAL_AGENT_CLS(
 
195
  llm=llm,
196
  template=template,
197
  output_format=output_format,
@@ -199,20 +218,25 @@ class AsyncAgentForInternLM(AsyncAgent):
199
  aggregator=aggregator,
200
  hooks=kwargs.pop('hooks', None),
201
  )
202
- self.plugin_executor = plugins and AsyncActionExecutor(plugins, hooks=action_hooks)
203
- self.interpreter_executor = interpreter and AsyncActionExecutor(interpreter, hooks=action_hooks)
 
 
 
204
  if not (self.plugin_executor or self.interpreter_executor):
205
  warnings.warn(
206
  'Neither plugin nor interpreter executor is initialized. '
207
- 'An exception will be thrown when the agent call a tool.'
208
- )
209
  self.finish_condition = finish_condition
210
  self.max_turn = max_turn
211
  super().__init__(**kwargs)
212
 
213
  async def forward(self, message: AgentMessage, session_id=0, **kwargs):
 
 
214
  for _ in range(self.max_turn):
215
- message = await self.agent(message, session_id=session_id, **kwargs)
 
216
  assert isinstance(message.formatted, dict)
217
  if self.finish_condition(message):
218
  return message
@@ -228,10 +252,15 @@ class AsyncAgentForInternLM(AsyncAgent):
228
  steps, tool_type = [], None
229
  for msg in self.agent.memory.get_memory(session_id):
230
  if msg.sender == self.agent.name:
231
- steps.append(dict(role='thought', content=msg.formatted['thought']))
 
232
  if msg.formatted['tool_type']:
233
  tool_type = msg.formatted['tool_type']
234
- steps.append(dict(role='tool', content=msg.formatted['action'], name=tool_type))
 
 
 
 
235
  elif msg.sender != 'user':
236
  feedback = dict(role='environment', content=msg.content)
237
  if tool_type:
@@ -250,17 +279,17 @@ class AsyncMathCoder(AsyncAgentForInternLM):
250
  memory: Dict = dict(type=Memory),
251
  output_format: Dict = dict(
252
  type=InterpreterParser,
253
- template=(
254
- 'Integrate step-by-step reasoning and Python code to solve math problems '
255
- 'using the following guidelines:\n'
256
- '- Analyze the question and write jupyter code to solve the problem;\n'
257
- r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
258
- 'units. \n'
259
- ),
260
- ),
261
  aggregator: Dict = dict(type=InternLMToolAggregator),
262
  action_hooks: List = [dict(type=InternLMActionProcessor)],
263
- finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
 
 
264
  max_turn: int = 6,
265
  **kwargs,
266
  ):
@@ -275,13 +304,13 @@ class AsyncMathCoder(AsyncAgentForInternLM):
275
  action_hooks=action_hooks,
276
  finish_condition=finish_condition,
277
  max_turn=max_turn,
278
- **kwargs,
279
- )
280
 
281
  async def forward(self, message: AgentMessage, session_id=0, **kwargs):
282
  try:
283
  return await super().forward(message, session_id, **kwargs)
284
  finally:
285
- interpreter = next(iter(self.interpreter_executor.actions.values()))
 
286
  if interpreter.name == 'AsyncIPythonInterpreter':
287
  await interpreter.close_session(session_id)
 
15
 
16
  API_PREFIX = (
17
  "This is the subfunction for tool '{tool_name}', you can use this tool. "
18
+ 'The description of this function is: \n{description}')
 
19
 
20
+ META_CN = ('当开启工具以及代码时,根据需求选择合适的工具进行调用')
21
 
22
+ INTERPRETER_CN = ('你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。'
23
+ '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。'
24
+ '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),'
25
+ '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),'
26
+ '文本处理和分析(比如文本解析和自然语言处理),'
27
+ '机器学习和数据科学(用于展示模型训练和数据可视化),'
28
+ '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。')
 
 
29
 
30
+ PLUGIN_CN = ('你可以使用如下工具:'
31
+ '\n{prompt}\n'
32
+ '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! '
33
+ '同时注意你可以使用的工具,不要随意捏造!')
 
 
34
 
35
 
36
+ def get_plugin_prompt(actions, api_desc_template=API_PREFIX):
37
  plugin_descriptions = []
38
  for action in actions if isinstance(actions, list) else [actions]:
39
  action = create_object(action)
 
41
  if action.is_toolkit:
42
  for api in action_desc['api_list']:
43
  api['name'] = f"{action.name}.{api['name']}"
44
+ api['description'] = api_desc_template.format(
45
+ tool_name=action.name, description=api['description'])
46
+ api['parameters'] = [
47
+ param for param in api['parameters']
48
+ if param['name'] in api['required']
49
+ ]
50
  plugin_descriptions.append(api)
51
  else:
52
+ action_desc['description'] = api_desc_template.format(
53
+ tool_name=action.name, description=action_desc['description'])
54
+ action_desc['parameters'] = [
55
+ param for param in action_desc['parameters']
56
+ if param['name'] in action_desc['required']
57
+ ]
58
  plugin_descriptions.append(action_desc)
59
  return json.dumps(plugin_descriptions, ensure_ascii=False, indent=4)
60
 
 
76
  parsers=[
77
  dict(type=PluginParser, template=PLUGIN_CN),
78
  dict(type=InterpreterParser, template=INTERPRETER_CN),
79
+ ]),
 
80
  aggregator: Dict = dict(type=InternLMToolAggregator),
81
  action_hooks: List = [dict(type=InternLMActionProcessor)],
82
+ finish_condition: Callable[
83
+ [AgentMessage],
84
+ bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
85
  max_turn: int = 4,
86
  **kwargs,
87
  ):
88
+ agent = dict(
89
+ type=self._INTERNAL_AGENT_CLS,
90
  llm=llm,
91
  template=template,
92
  output_format=output_format,
 
94
  aggregator=aggregator,
95
  hooks=kwargs.pop('hooks', None),
96
  )
97
+ self.agent = create_object(agent)
98
+ self.plugin_executor = plugins and ActionExecutor(
99
+ plugins, hooks=action_hooks)
100
+ self.interpreter_executor = interpreter and ActionExecutor(
101
+ interpreter, hooks=action_hooks)
102
  if not (self.plugin_executor or self.interpreter_executor):
103
  warnings.warn(
104
  'Neither plugin nor interpreter executor is initialized. '
105
+ 'An exception will be thrown when the agent call a tool.')
 
106
  self.finish_condition = finish_condition
107
  self.max_turn = max_turn
108
  super().__init__(**kwargs)
109
 
110
  def forward(self, message: AgentMessage, session_id=0, **kwargs):
111
+ if isinstance(message, str):
112
+ message = AgentMessage(sender='user', content=message)
113
  for _ in range(self.max_turn):
114
  message = self.agent(message, session_id=session_id, **kwargs)
115
  assert isinstance(message.formatted, dict)
 
127
  steps, tool_type = [], None
128
  for msg in self.agent.memory.get_memory(session_id):
129
  if msg.sender == self.agent.name:
130
+ steps.append(
131
+ dict(role='thought', content=msg.formatted['thought']))
132
  if msg.formatted['tool_type']:
133
  tool_type = msg.formatted['tool_type']
134
+ steps.append(
135
+ dict(
136
+ role='tool',
137
+ content=msg.formatted['action'],
138
+ name=tool_type))
139
  elif msg.sender != 'user':
140
  feedback = dict(role='environment', content=msg.content)
141
  if tool_type:
 
149
  def __init__(
150
  self,
151
  llm: Union[BaseLLM, Dict],
152
+ interpreter: dict = dict(
153
+ type=IPythonInteractive, timeout=20, max_out_len=8192),
154
  template: Union[str, dict, List[dict]] = None,
155
  memory: Dict = dict(type=Memory),
156
  output_format: Dict = dict(
157
  type=InterpreterParser,
158
+ template=
159
+ ('Integrate step-by-step reasoning and Python code to solve math problems '
160
+ 'using the following guidelines:\n'
161
+ '- Analyze the question and write jupyter code to solve the problem;\n'
162
+ r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
163
+ 'units. \n')),
 
 
164
  aggregator: Dict = dict(type=InternLMToolAggregator),
165
  action_hooks: List = [dict(type=InternLMActionProcessor)],
166
+ finish_condition: Callable[
167
+ [AgentMessage],
168
+ bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
169
  max_turn: int = 6,
170
  **kwargs,
171
  ):
 
180
  action_hooks=action_hooks,
181
  finish_condition=finish_condition,
182
  max_turn=max_turn,
183
+ **kwargs)
 
184
 
185
 
186
  class AsyncAgentForInternLM(AsyncAgent):
 
200
  parsers=[
201
  dict(type=PluginParser, template=PLUGIN_CN),
202
  dict(type=InterpreterParser, template=INTERPRETER_CN),
203
+ ]),
 
204
  aggregator: Dict = dict(type=InternLMToolAggregator),
205
  action_hooks: List = [dict(type=InternLMActionProcessor)],
206
+ finish_condition: Callable[
207
+ [AgentMessage],
208
+ bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
209
  max_turn: int = 4,
210
  **kwargs,
211
  ):
212
+ agent = dict(
213
+ type=self._INTERNAL_AGENT_CLS,
214
  llm=llm,
215
  template=template,
216
  output_format=output_format,
 
218
  aggregator=aggregator,
219
  hooks=kwargs.pop('hooks', None),
220
  )
221
+ self.agent = create_object(agent)
222
+ self.plugin_executor = plugins and AsyncActionExecutor(
223
+ plugins, hooks=action_hooks)
224
+ self.interpreter_executor = interpreter and AsyncActionExecutor(
225
+ interpreter, hooks=action_hooks)
226
  if not (self.plugin_executor or self.interpreter_executor):
227
  warnings.warn(
228
  'Neither plugin nor interpreter executor is initialized. '
229
+ 'An exception will be thrown when the agent call a tool.')
 
230
  self.finish_condition = finish_condition
231
  self.max_turn = max_turn
232
  super().__init__(**kwargs)
233
 
234
  async def forward(self, message: AgentMessage, session_id=0, **kwargs):
235
+ if isinstance(message, str):
236
+ message = AgentMessage(sender='user', content=message)
237
  for _ in range(self.max_turn):
238
+ message = await self.agent(
239
+ message, session_id=session_id, **kwargs)
240
  assert isinstance(message.formatted, dict)
241
  if self.finish_condition(message):
242
  return message
 
252
  steps, tool_type = [], None
253
  for msg in self.agent.memory.get_memory(session_id):
254
  if msg.sender == self.agent.name:
255
+ steps.append(
256
+ dict(role='thought', content=msg.formatted['thought']))
257
  if msg.formatted['tool_type']:
258
  tool_type = msg.formatted['tool_type']
259
+ steps.append(
260
+ dict(
261
+ role='tool',
262
+ content=msg.formatted['action'],
263
+ name=tool_type))
264
  elif msg.sender != 'user':
265
  feedback = dict(role='environment', content=msg.content)
266
  if tool_type:
 
279
  memory: Dict = dict(type=Memory),
280
  output_format: Dict = dict(
281
  type=InterpreterParser,
282
+ template=
283
+ ('Integrate step-by-step reasoning and Python code to solve math problems '
284
+ 'using the following guidelines:\n'
285
+ '- Analyze the question and write jupyter code to solve the problem;\n'
286
+ r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
287
+ 'units. \n')),
 
 
288
  aggregator: Dict = dict(type=InternLMToolAggregator),
289
  action_hooks: List = [dict(type=InternLMActionProcessor)],
290
+ finish_condition: Callable[
291
+ [AgentMessage],
292
+ bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
293
  max_turn: int = 6,
294
  **kwargs,
295
  ):
 
304
  action_hooks=action_hooks,
305
  finish_condition=finish_condition,
306
  max_turn=max_turn,
307
+ **kwargs)
 
308
 
309
  async def forward(self, message: AgentMessage, session_id=0, **kwargs):
310
  try:
311
  return await super().forward(message, session_id, **kwargs)
312
  finally:
313
+ interpreter = next(
314
+ iter(self.interpreter_executor.actions.values()))
315
  if interpreter.name == 'AsyncIPythonInterpreter':
316
  await interpreter.close_session(session_id)
lagent/distributed/http_serve/api_server.py CHANGED
@@ -3,7 +3,6 @@ import os
3
  import subprocess
4
  import sys
5
  import time
6
- import threading
7
 
8
  import aiohttp
9
  import requests
@@ -78,21 +77,14 @@ class HTTPAgentServer(HTTPAgentClient):
78
  stderr=subprocess.STDOUT,
79
  text=True)
80
 
81
- self.service_started = False
82
-
83
- def log_output(stream):
84
- if stream is not None:
85
- for line in iter(stream.readline, ''):
86
- print(line, end='')
87
- if 'Uvicorn running on' in line:
88
- self.service_started = True
89
-
90
- # Start log output thread
91
- threading.Thread(target=log_output, args=(self.process.stdout,), daemon=True).start()
92
- threading.Thread(target=log_output, args=(self.process.stderr,), daemon=True).start()
93
-
94
- # Waiting for the service to start
95
- while not self.service_started:
96
  time.sleep(0.1)
97
 
98
  def shutdown(self):
 
3
  import subprocess
4
  import sys
5
  import time
 
6
 
7
  import aiohttp
8
  import requests
 
77
  stderr=subprocess.STDOUT,
78
  text=True)
79
 
80
+ while True:
81
+ output = self.process.stdout.readline()
82
+ if not output: # 如果读到 EOF,跳出循环
83
+ break
84
+ sys.stdout.write(output) # 打印到标准输出
85
+ sys.stdout.flush()
86
+ if 'Uvicorn running on' in output: # 根据实际输出调整
87
+ break
 
 
 
 
 
 
 
88
  time.sleep(0.1)
89
 
90
  def shutdown(self):
lagent/hooks/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (387 Bytes). View file
 
lagent/hooks/__pycache__/action_preprocessor.cpython-310.pyc ADDED
Binary file (2.29 kB). View file
 
lagent/hooks/__pycache__/hook.cpython-310.pyc ADDED
Binary file (1.53 kB). View file
 
lagent/hooks/__pycache__/logger.cpython-310.pyc ADDED
Binary file (1.77 kB). View file
 
lagent/hooks/logger.py CHANGED
@@ -1,4 +1,5 @@
1
  import random
 
2
 
3
  from termcolor import COLORS, colored
4
 
@@ -7,10 +8,10 @@ from .hook import Hook
7
 
8
 
9
  class MessageLogger(Hook):
10
- def __init__(self, name: str = 'lagent', add_file_handler: bool = False):
 
11
  self.logger = get_logger(
12
- name, 'info', '%(asctime)s %(levelname)8s %(name)8s - %(message)s', add_file_handler=add_file_handler
13
- )
14
  self.sender2color = {}
15
 
16
  def before_agent(self, agent, messages, session_id):
@@ -28,5 +29,9 @@ class MessageLogger(Hook):
28
 
29
  def _process_message(self, message, session_id):
30
  sender = message.sender
31
- color = self.sender2color.setdefault(sender, random.choice(list(COLORS)))
32
- self.logger.info(colored(f'session id: {session_id}, message sender: {sender}\n' f'{message.content}', color))
 
 
 
 
 
1
  import random
2
+ from typing import Optional
3
 
4
  from termcolor import COLORS, colored
5
 
 
8
 
9
 
10
  class MessageLogger(Hook):
11
+
12
+ def __init__(self, name: str = 'lagent'):
13
  self.logger = get_logger(
14
+ name, 'info', '%(asctime)s %(levelname)8s %(name)8s - %(message)s')
 
15
  self.sender2color = {}
16
 
17
  def before_agent(self, agent, messages, session_id):
 
29
 
30
  def _process_message(self, message, session_id):
31
  sender = message.sender
32
+ color = self.sender2color.setdefault(sender,
33
+ random.choice(list(COLORS)))
34
+ self.logger.info(
35
+ colored(
36
+ f'session id: {session_id}, message sender: {sender}\n'
37
+ f'{message.content}', color))
lagent/llms/__init__.py CHANGED
@@ -1,15 +1,9 @@
1
- from .anthropic_llm import AsyncClaudeAPI, ClaudeAPI
2
  from .base_api import AsyncBaseAPILLM, BaseAPILLM
3
  from .base_llm import AsyncBaseLLM, BaseLLM
4
  from .huggingface import HFTransformer, HFTransformerCasualLM, HFTransformerChat
5
- from .lmdeploy_wrapper import (
6
- AsyncLMDeployClient,
7
- AsyncLMDeployPipeline,
8
- AsyncLMDeployServer,
9
- LMDeployClient,
10
- LMDeployPipeline,
11
- LMDeployServer,
12
- )
13
  from .meta_template import INTERNLM2_META
14
  from .openai import GPTAPI, AsyncGPTAPI
15
  from .sensenova import SensenovaAPI
@@ -35,6 +29,4 @@ __all__ = [
35
  'VllmModel',
36
  'AsyncVllmModel',
37
  'SensenovaAPI',
38
- 'AsyncClaudeAPI',
39
- 'ClaudeAPI',
40
  ]
 
 
1
  from .base_api import AsyncBaseAPILLM, BaseAPILLM
2
  from .base_llm import AsyncBaseLLM, BaseLLM
3
  from .huggingface import HFTransformer, HFTransformerCasualLM, HFTransformerChat
4
+ from .lmdeploy_wrapper import (AsyncLMDeployClient, AsyncLMDeployPipeline,
5
+ AsyncLMDeployServer, LMDeployClient,
6
+ LMDeployPipeline, LMDeployServer)
 
 
 
 
 
7
  from .meta_template import INTERNLM2_META
8
  from .openai import GPTAPI, AsyncGPTAPI
9
  from .sensenova import SensenovaAPI
 
29
  'VllmModel',
30
  'AsyncVllmModel',
31
  'SensenovaAPI',
 
 
32
  ]