Commit a6a9bfa — Add files
Parent: bc74a8c

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
- .github/dependabot.yml +0 -12
- .pre-commit-config.yaml +6 -8
- app.py +2 -2
- examples/agent_api_web_demo.py +4 -2
- lagent.egg-info/PKG-INFO +608 -0
- lagent.egg-info/SOURCES.txt +71 -0
- lagent.egg-info/dependency_links.txt +1 -0
- lagent.egg-info/requires.txt +59 -0
- lagent.egg-info/top_level.txt +1 -0
- lagent/__pycache__/__init__.cpython-310.pyc +0 -0
- lagent/__pycache__/schema.cpython-310.pyc +0 -0
- lagent/__pycache__/version.cpython-310.pyc +0 -0
- lagent/actions/__init__.py +20 -31
- lagent/actions/__pycache__/__init__.cpython-310.pyc +0 -0
- lagent/actions/__pycache__/action_executor.cpython-310.pyc +0 -0
- lagent/actions/__pycache__/arxiv_search.cpython-310.pyc +0 -0
- lagent/actions/__pycache__/base_action.cpython-310.pyc +0 -0
- lagent/actions/__pycache__/bing_map.cpython-310.pyc +0 -0
- lagent/actions/__pycache__/builtin_actions.cpython-310.pyc +0 -0
- lagent/actions/__pycache__/google_scholar_search.cpython-310.pyc +0 -0
- lagent/actions/__pycache__/google_search.cpython-310.pyc +0 -0
- lagent/actions/__pycache__/ipython_interactive.cpython-310.pyc +0 -0
- lagent/actions/__pycache__/ipython_interpreter.cpython-310.pyc +0 -0
- lagent/actions/__pycache__/ipython_manager.cpython-310.pyc +0 -0
- lagent/actions/__pycache__/parser.cpython-310.pyc +0 -0
- lagent/actions/__pycache__/ppt.cpython-310.pyc +0 -0
- lagent/actions/__pycache__/python_interpreter.cpython-310.pyc +0 -0
- lagent/actions/__pycache__/weather_query.cpython-310.pyc +0 -0
- lagent/actions/__pycache__/web_browser.cpython-310.pyc +0 -0
- lagent/actions/base_action.py +55 -42
- lagent/actions/weather_query.py +71 -0
- lagent/actions/web_browser.py +283 -232
- lagent/agents/__init__.py +4 -28
- lagent/agents/__pycache__/__init__.cpython-310.pyc +0 -0
- lagent/agents/__pycache__/agent.cpython-310.pyc +0 -0
- lagent/agents/__pycache__/react.cpython-310.pyc +0 -0
- lagent/agents/__pycache__/stream.cpython-310.pyc +0 -0
- lagent/agents/agent.py +117 -187
- lagent/agents/aggregator/__pycache__/__init__.cpython-310.pyc +0 -0
- lagent/agents/aggregator/__pycache__/default_aggregator.cpython-310.pyc +0 -0
- lagent/agents/aggregator/__pycache__/tool_aggregator.cpython-310.pyc +0 -0
- lagent/agents/react.py +73 -76
- lagent/agents/stream.py +94 -65
- lagent/distributed/http_serve/api_server.py +8 -16
- lagent/hooks/__pycache__/__init__.cpython-310.pyc +0 -0
- lagent/hooks/__pycache__/action_preprocessor.cpython-310.pyc +0 -0
- lagent/hooks/__pycache__/hook.cpython-310.pyc +0 -0
- lagent/hooks/__pycache__/logger.cpython-310.pyc +0 -0
- lagent/hooks/logger.py +10 -5
- lagent/llms/__init__.py +3 -11
.github/dependabot.yml
DELETED
@@ -1,12 +0,0 @@
-# To get started with Dependabot version updates, you'll need to specify which
-# package ecosystems to update and where the package manifests are located.
-# Please see the documentation for all configuration options:
-# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
-
-version: 2
-updates:
-  - package-ecosystem: "" # See documentation for possible values
-    directory: "/" # Location of package manifests
-    schedule:
-      interval: "weekly"
-
.pre-commit-config.yaml
CHANGED
@@ -1,22 +1,20 @@
 exclude: ^(tests/data|scripts|ftdp/protocols|ftdp/template_configs|ftdp/tool_dicts)/
 repos:
   - repo: https://github.com/PyCQA/flake8
-    rev: 7.
+    rev: 7.0.0
     hooks:
       - id: flake8
   - repo: https://github.com/PyCQA/isort
     rev: 5.13.2
     hooks:
       - id: isort
-        args: ["--profile", "black", "--filter-files", "--line-width", "119"]
   - repo: https://github.com/psf/black
-    rev:
+    rev: 22.8.0
     hooks:
       - id: black
         args: ["--line-length", "119", "--skip-string-normalization"]
-
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev:
+    rev: v4.5.0
     hooks:
       - id: trailing-whitespace
       - id: check-yaml
@@ -29,7 +27,7 @@ repos:
       - id: mixed-line-ending
         args: ["--fix=lf"]
   - repo: https://github.com/executablebooks/mdformat
-    rev: 0.7.
+    rev: 0.7.17
     hooks:
       - id: mdformat
         args: ["--number"]
@@ -38,11 +36,11 @@ repos:
         - mdformat_frontmatter
         - linkify-it-py
   - repo: https://github.com/codespell-project/codespell
-    rev: v2.
+    rev: v2.2.6
     hooks:
       - id: codespell
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.
+    rev: v3.15.0
     hooks:
       - id: pyupgrade
         args: ["--py36-plus"]
app.py
CHANGED
@@ -2,8 +2,8 @@
 Author: Highthoughts cht7613@gmail.com
 Date: 2025-01-30 11:02:01
 LastEditors: Highthoughts cht7613@gmail.com
-LastEditTime: 2025-01-30 11:
-FilePath: \
+LastEditTime: 2025-01-30 11:41:16
+FilePath: \AgentTest\app.py
 Description:
 
 Copyright (c) 2025 by Cuihaitao, All Rights Reserved.
examples/agent_api_web_demo.py
CHANGED
@@ -2,7 +2,8 @@ import copy
 import os
 from typing import List
 import streamlit as st
-from lagent.actions import ArxivSearch
+# from lagent.actions import ArxivSearch
+from lagent.actions import ArxivSearch, WeatherQuery
 from lagent.prompts.parsers import PluginParser
 from lagent.agents.stream import INTERPRETER_CN, META_CN, PLUGIN_CN, AgentForInternLM, get_plugin_prompt
 from lagent.llms import GPTAPI
@@ -17,6 +18,7 @@ class SessionState:
         # 初始化插件列表
         action_list = [
             ArxivSearch(),
+            WeatherQuery(),
         ]
         st.session_state['plugin_map'] = {action.name: action for action in action_list}
         st.session_state['model_map'] = {}  # 存储模型实例
@@ -50,7 +52,7 @@ class StreamlitUI:
         #     page_title='lagent-web',
         #     page_icon='./docs/imgs/lagent_icon.png'
         # )
-
+        st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
 
     def setup_sidebar(self):
         """设置侧边栏,选择模型和插件。"""
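The hunks above only register `WeatherQuery()` alongside `ArxivSearch()` in the session's `action_list` and `plugin_map`; how the demo later turns the selected plugins into an executor is outside this diff. Below is a minimal sketch of that step in the usual Lagent pattern (`selected_plugins` is a hypothetical stand-in for the sidebar selection, not part of this commit):

```python
import os

from lagent.actions import ActionExecutor, ArxivSearch, WeatherQuery

# WeatherQuery.__init__ reads its QWeather key from this environment variable.
os.environ.setdefault('weather_token', 'YOUR_QWEATHER_API_KEY')

# Same shape as the session state built in the hunk above.
plugin_map = {action.name: action for action in [ArxivSearch(), WeatherQuery()]}

# Hypothetical sidebar selection; in the real demo this comes from Streamlit widgets.
selected_plugins = ['WeatherQuery']

# Wrap the chosen plugins in an ActionExecutor so the agent can invoke them.
executor = ActionExecutor(actions=[plugin_map[name] for name in selected_plugins])
```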
lagent.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,608 @@
Metadata-Version: 2.2
Name: lagent
Version: 0.5.0rc1
Summary: A lightweight framework for building LLM-based agents
Home-page: https://github.com/InternLM/lagent
License: Apache 2.0
Keywords: artificial general intelligence,agent,agi,llm
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: aiohttp
Requires-Dist: arxiv
Requires-Dist: asyncache
Requires-Dist: asyncer
Requires-Dist: distro
Requires-Dist: duckduckgo_search==5.3.1b1
Requires-Dist: filelock
Requires-Dist: func_timeout
Requires-Dist: griffe<1.0
Requires-Dist: json5
Requires-Dist: jsonschema
Requires-Dist: jupyter==1.0.0
Requires-Dist: jupyter_client==8.6.2
Requires-Dist: jupyter_core==5.7.2
Requires-Dist: pydantic==2.6.4
Requires-Dist: requests
Requires-Dist: termcolor
Requires-Dist: tiktoken
Requires-Dist: timeout-decorator
Requires-Dist: typing-extensions
Provides-Extra: all
Requires-Dist: google-search-results; extra == "all"
Requires-Dist: lmdeploy>=0.2.5; extra == "all"
Requires-Dist: pillow; extra == "all"
Requires-Dist: python-pptx; extra == "all"
Requires-Dist: timeout_decorator; extra == "all"
Requires-Dist: torch; extra == "all"
Requires-Dist: transformers<=4.40,>=4.34; extra == "all"
Requires-Dist: vllm>=0.3.3; extra == "all"
Requires-Dist: aiohttp; extra == "all"
Requires-Dist: arxiv; extra == "all"
Requires-Dist: asyncache; extra == "all"
Requires-Dist: asyncer; extra == "all"
Requires-Dist: distro; extra == "all"
Requires-Dist: duckduckgo_search==5.3.1b1; extra == "all"
Requires-Dist: filelock; extra == "all"
Requires-Dist: func_timeout; extra == "all"
Requires-Dist: griffe<1.0; extra == "all"
Requires-Dist: json5; extra == "all"
Requires-Dist: jsonschema; extra == "all"
Requires-Dist: jupyter==1.0.0; extra == "all"
Requires-Dist: jupyter_client==8.6.2; extra == "all"
Requires-Dist: jupyter_core==5.7.2; extra == "all"
Requires-Dist: pydantic==2.6.4; extra == "all"
Requires-Dist: requests; extra == "all"
Requires-Dist: termcolor; extra == "all"
Requires-Dist: tiktoken; extra == "all"
Requires-Dist: timeout-decorator; extra == "all"
Requires-Dist: typing-extensions; extra == "all"
Provides-Extra: optional
Requires-Dist: google-search-results; extra == "optional"
Requires-Dist: lmdeploy>=0.2.5; extra == "optional"
Requires-Dist: pillow; extra == "optional"
Requires-Dist: python-pptx; extra == "optional"
Requires-Dist: timeout_decorator; extra == "optional"
Requires-Dist: torch; extra == "optional"
Requires-Dist: transformers<=4.40,>=4.34; extra == "optional"
Requires-Dist: vllm>=0.3.3; extra == "optional"
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: keywords
Dynamic: license
Dynamic: provides-extra
Dynamic: requires-dist
Dynamic: summary

<div id="top"></div>
<div align="center">
<img src="docs/imgs/lagent_logo.png" width="450"/>

[](https://lagent.readthedocs.io/en/latest/)
[](https://pypi.org/project/lagent)
[](https://github.com/InternLM/lagent/tree/main/LICENSE)
[](https://github.com/InternLM/lagent/issues)
[](https://github.com/InternLM/lagent/issues)

</div>

<p align="center">
    👋 join us on <a href="https://twitter.com/intern_lm" target="_blank">𝕏 (Twitter)</a>, <a href="https://discord.gg/xa29JuW87d" target="_blank">Discord</a> and <a href="https://r.vansin.top/?r=internwx" target="_blank">WeChat</a>
</p>

## Installation

Install from source:

```bash
git clone https://github.com/InternLM/lagent.git
cd lagent
pip install -e .
```

## Usage

Lagent is inspired by the design philosophy of PyTorch. We expect that the analogy of neural network layers will make the workflow clearer and more intuitive, so users only need to focus on creating layers and defining message passing between them in a Pythonic way. This is a simple tutorial to get you quickly started with building multi-agent applications.

### Models as Agents

Agents use `AgentMessage` for communication.

```python
from typing import Dict, List
from lagent.agents import Agent
from lagent.schema import AgentMessage
from lagent.llms import VllmModel, INTERNLM2_META

llm = VllmModel(
    path='Qwen/Qwen2-7B-Instruct',
    meta_template=INTERNLM2_META,
    tp=1,
    top_k=1,
    temperature=1.0,
    stop_words=['<|im_end|>'],
    max_new_tokens=1024,
)
system_prompt = '你的回答只能从“典”、“孝”、“急”三个字中选一个。'
agent = Agent(llm, system_prompt)

user_msg = AgentMessage(sender='user', content='今天天气情况')
bot_msg = agent(user_msg)
print(bot_msg)
```

```
content='急' sender='Agent' formatted=None extra_info=None type=None receiver=None stream_state=<AgentStatusCode.END: 0>
```

### Memory as State

Both input and output messages will be added to the memory of `Agent` in each forward pass. This is performed in `__call__` rather than `forward`. See the following pseudo code

```python
    def __call__(self, *message):
        message = pre_hooks(message)
        add_memory(message)
        message = self.forward(*message)
        add_memory(message)
        message = post_hooks(message)
        return message
```

Inspect the memory in two ways

```python
memory: List[AgentMessage] = agent.memory.get_memory()
print(memory)
print('-' * 120)
dumped_memory: Dict[str, List[dict]] = agent.state_dict()
print(dumped_memory['memory'])
```

```
[AgentMessage(content='今天天气情况', sender='user', formatted=None, extra_info=None, type=None, receiver=None, stream_state=<AgentStatusCode.END: 0>), AgentMessage(content='急', sender='Agent', formatted=None, extra_info=None, type=None, receiver=None, stream_state=<AgentStatusCode.END: 0>)]
------------------------------------------------------------------------------------------------------------------------
[{'content': '今天天气情况', 'sender': 'user', 'formatted': None, 'extra_info': None, 'type': None, 'receiver': None, 'stream_state': <AgentStatusCode.END: 0>}, {'content': '急', 'sender': 'Agent', 'formatted': None, 'extra_info': None, 'type': None, 'receiver': None, 'stream_state': <AgentStatusCode.END: 0>}]
```

Clear the memory of this session(`session_id=0` by default):

```python
agent.memory.reset()
```

### Custom Message Aggregation

`DefaultAggregator` is called under the hood to assemble and convert `AgentMessage` to OpenAI message format.

```python
    def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]:
        formatted_messages = self.aggregator.aggregate(
            self.memory.get(session_id),
            self.name,
            self.output_format,
            self.template,
        )
        llm_response = self.llm.chat(formatted_messages, **kwargs)
        ...
```

Implement a simple aggregator that can receive few-shots

```python
from typing import List, Union
from lagent.memory import Memory
from lagent.prompts import StrParser
from lagent.agents.aggregator import DefaultAggregator

class FewshotAggregator(DefaultAggregator):
    def __init__(self, few_shot: List[dict] = None):
        self.few_shot = few_shot or []

    def aggregate(self,
                  messages: Memory,
                  name: str,
                  parser: StrParser = None,
                  system_instruction: Union[str, dict, List[dict]] = None) -> List[dict]:
        _message = []
        if system_instruction:
            _message.extend(
                self.aggregate_system_intruction(system_instruction))
        _message.extend(self.few_shot)
        messages = messages.get_memory()
        for message in messages:
            if message.sender == name:
                _message.append(
                    dict(role='assistant', content=str(message.content)))
            else:
                user_message = message.content
                if len(_message) > 0 and _message[-1]['role'] == 'user':
                    _message[-1]['content'] += user_message
                else:
                    _message.append(dict(role='user', content=user_message))
        return _message

agent = Agent(
    llm,
    aggregator=FewshotAggregator(
        [
            {"role": "user", "content": "今天天气"},
            {"role": "assistant", "content": "【晴】"},
        ]
    )
)
user_msg = AgentMessage(sender='user', content='昨天天气')
bot_msg = agent(user_msg)
print(bot_msg)
```

```
content='【多云转晴,夜间有轻微降温】' sender='Agent' formatted=None extra_info=None type=None receiver=None stream_state=<AgentStatusCode.END: 0>
```

### Flexible Response Formatting

In `AgentMessage`, `formatted` is reserved to store information parsed by `output_format` from the model output.

```python
    def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]:
        ...
        llm_response = self.llm.chat(formatted_messages, **kwargs)
        if self.output_format:
            formatted_messages = self.output_format.parse_response(llm_response)
            return AgentMessage(
                sender=self.name,
                content=llm_response,
                formatted=formatted_messages,
            )
        ...
```

Use a tool parser as follows

````python
from lagent.prompts.parsers import ToolParser

system_prompt = "逐步分析并编写Python代码解决以下问题。"
parser = ToolParser(tool_type='code interpreter', begin='```python\n', end='\n```\n')
llm.gen_params['stop_words'].append('\n```\n')
agent = Agent(llm, system_prompt, output_format=parser)

user_msg = AgentMessage(
    sender='user',
    content='Marie is thinking of a multiple of 63, while Jay is thinking of a '
    'factor of 63. They happen to be thinking of the same number. There are '
    'two possibilities for the number that each of them is thinking of, one '
    'positive and one negative. Find the product of these two numbers.')
bot_msg = agent(user_msg)
print(bot_msg.model_dump_json(indent=4))
````

````
{
    "content": "首先,我们需要找出63的所有正因数和负因数。63的正因数可以通过分解63的质因数来找出,即\\(63 = 3^2 \\times 7\\)。因此,63的正因数包括1, 3, 7, 9, 21, 和 63。对于负因数,我们只需将上述正因数乘以-1。\n\n接下来,我们需要找出与63的正因数相乘的结果为63的数,以及与63的负因数相乘的结果为63的数。这可以通过将63除以每个正因数和负因数来实现。\n\n最后,我们将找到的两个数相乘得到最终答案。\n\n下面是Python代码实现:\n\n```python\ndef find_numbers():\n # 正因数\n positive_factors = [1, 3, 7, 9, 21, 63]\n # 负因数\n negative_factors = [-1, -3, -7, -9, -21, -63]\n \n # 找到与正因数相乘的结果为63的数\n positive_numbers = [63 / factor for factor in positive_factors]\n # 找到与负因数相乘的结果为63的数\n negative_numbers = [-63 / factor for factor in negative_factors]\n \n # 计算两个数的乘积\n product = positive_numbers[0] * negative_numbers[0]\n \n return product\n\nresult = find_numbers()\nprint(result)",
    "sender": "Agent",
    "formatted": {
        "tool_type": "code interpreter",
        "thought": "首先,我们需要找出63的所有正因数和负因数。63的正因数可以通过分解63的质因数来找出,即\\(63 = 3^2 \\times 7\\)。因此,63的正因数包括1, 3, 7, 9, 21, 和 63。对于负因数,我们只需将上述正因数乘以-1。\n\n接下来,我们需要找出与63的正因数相乘的结果为63的数,以及与63的负因数相乘的结果为63的数。这可以通过将63除以每个正因数和负因数来实现。\n\n最后,我们将找到的两个数相乘得到最终答案。\n\n下面是Python代码实现:\n\n",
        "action": "def find_numbers():\n # 正因数\n positive_factors = [1, 3, 7, 9, 21, 63]\n # 负因数\n negative_factors = [-1, -3, -7, -9, -21, -63]\n \n # 找到与正因数相乘的结果为63的数\n positive_numbers = [63 / factor for factor in positive_factors]\n # 找到与负因数相乘的结果为63的数\n negative_numbers = [-63 / factor for factor in negative_factors]\n \n # 计算两个数的乘积\n product = positive_numbers[0] * negative_numbers[0]\n \n return product\n\nresult = find_numbers()\nprint(result)",
        "status": 1
    },
    "extra_info": null,
    "type": null,
    "receiver": null,
    "stream_state": 0
}
````

### Consistency of Tool Calling

`ActionExecutor` uses the same communication data structure as `Agent`, but requires the content of input `AgentMessage` to be a dict containing:

- `name`: tool name, e.g. `'IPythonInterpreter'`, `'WebBrowser.search'`.
- `parameters`: keyword arguments of the tool API, e.g. `{'command': 'import math;math.sqrt(2)'}`, `{'query': ['recent progress in AI']}`.

You can register custom hooks for message conversion.

```python
from lagent.hooks import Hook
from lagent.schema import ActionReturn, ActionStatusCode, AgentMessage
from lagent.actions import ActionExecutor, IPythonInteractive

class CodeProcessor(Hook):
    def before_action(self, executor, message, session_id):
        message = message.copy(deep=True)
        message.content = dict(
            name='IPythonInteractive', parameters={'command': message.formatted['action']}
        )
        return message

    def after_action(self, executor, message, session_id):
        action_return = message.content
        if isinstance(action_return, ActionReturn):
            if action_return.state == ActionStatusCode.SUCCESS:
                response = action_return.format_result()
            else:
                response = action_return.errmsg
        else:
            response = action_return
        message.content = response
        return message

executor = ActionExecutor(actions=[IPythonInteractive()], hooks=[CodeProcessor()])
bot_msg = AgentMessage(
    sender='Agent',
    content='首先,我们需要...',
    formatted={
        'tool_type': 'code interpreter',
        'thought': '首先,我们需要...',
        'action': 'def find_numbers():\n # 正因数\n positive_factors = [1, 3, 7, 9, 21, 63]\n # 负因数\n negative_factors = [-1, -3, -7, -9, -21, -63]\n \n # 找到与正因数相乘的结果为63的数\n positive_numbers = [63 / factor for factor in positive_factors]\n # 找到与负因数相乘的结果为63的数\n negative_numbers = [-63 / factor for factor in negative_factors]\n \n # 计算两个数的乘积\n product = positive_numbers[0] * negative_numbers[0]\n \n return product\n\nresult = find_numbers()\nprint(result)',
        'status': 1
    })
executor_msg = executor(bot_msg)
print(executor_msg)
```

```
content='3969.0' sender='ActionExecutor' formatted=None extra_info=None type=None receiver=None stream_state=<AgentStatusCode.END: 0>
```

**For convenience, Lagent provides `InternLMActionProcessor` which is adapted to messages formatted by `ToolParser` as mentioned above.**

### Dual Interfaces

Lagent adopts dual interface design, where almost every component(LLMs, actions, action executors...) has the corresponding asynchronous variant by prefixing its identifier with 'Async'. It is recommended to use synchronous agents for debugging and asynchronous ones for large-scale inference to make the most of idle CPU and GPU resources.

However, make sure the internal consistency of agents, i.e. asynchronous agents should be equipped with asynchronous LLMs and asynchronous action executors that drive asynchronous tools.

```python
from lagent.llms import VllmModel, AsyncVllmModel, LMDeployPipeline, AsyncLMDeployPipeline
from lagent.actions import ActionExecutor, AsyncActionExecutor, WebBrowser, AsyncWebBrowser
from lagent.agents import Agent, AsyncAgent, AgentForInternLM, AsyncAgentForInternLM
```

______________________________________________________________________

## Practice

- **Try to implement `forward` instead of `__call__` of subclasses unless necessary.**
- **Always include the `session_id` argument explicitly, which is designed for isolation of memory, LLM requests and tool invocation(e.g. maintain multiple independent IPython environments) in concurrency.**

### Single Agent

Math agents that solve problems by programming

````python
from lagent.agents.aggregator import InternLMToolAggregator

class Coder(Agent):
    def __init__(self, model_path, system_prompt, max_turn=3):
        super().__init__()
        llm = VllmModel(
            path=model_path,
            meta_template=INTERNLM2_META,
            tp=1,
            top_k=1,
            temperature=1.0,
            stop_words=['\n```\n', '<|im_end|>'],
            max_new_tokens=1024,
        )
        self.agent = Agent(
            llm,
            system_prompt,
            output_format=ToolParser(
                tool_type='code interpreter', begin='```python\n', end='\n```\n'
            ),
            # `InternLMToolAggregator` is adapted to `ToolParser` for aggregating
            # messages with tool invocations and execution results
            aggregator=InternLMToolAggregator(),
        )
        self.executor = ActionExecutor([IPythonInteractive()], hooks=[CodeProcessor()])
        self.max_turn = max_turn

    def forward(self, message: AgentMessage, session_id=0) -> AgentMessage:
        for _ in range(self.max_turn):
            message = self.agent(message, session_id=session_id)
            if message.formatted['tool_type'] is None:
                return message
            message = self.executor(message, session_id=session_id)
        return message

coder = Coder('Qwen/Qwen2-7B-Instruct', 'Solve the problem step by step with assistance of Python code')
query = AgentMessage(
    sender='user',
    content='Find the projection of $\\mathbf{a}$ onto $\\mathbf{b} = '
    '\\begin{pmatrix} 1 \\\\ -3 \\end{pmatrix}$ if $\\mathbf{a} \\cdot \\mathbf{b} = 2.$'
)
answer = coder(query)
print(answer.content)
print('-' * 120)
for msg in coder.state_dict()['agent.memory']:
    print('*' * 80)
    print(f'{msg["sender"]}:\n\n{msg["content"]}')
````

### Multiple Agents

Asynchronous blogging agents that improve writing quality by self-refinement ([original AutoGen example](https://microsoft.github.io/autogen/0.2/docs/topics/prompting-and-reasoning/reflection/))

```python
import asyncio
import os
from lagent.llms import AsyncGPTAPI
from lagent.agents import AsyncAgent
os.environ['OPENAI_API_KEY'] = 'YOUR_API_KEY'

class PrefixedMessageHook(Hook):
    def __init__(self, prefix: str, senders: list = None):
        self.prefix = prefix
        self.senders = senders or []

    def before_agent(self, agent, messages, session_id):
        for message in messages:
            if message.sender in self.senders:
                message.content = self.prefix + message.content

class AsyncBlogger(AsyncAgent):
    def __init__(self, model_path, writer_prompt, critic_prompt, critic_prefix='', max_turn=3):
        super().__init__()
        llm = AsyncGPTAPI(model_type=model_path, retry=5, max_new_tokens=2048)
        self.writer = AsyncAgent(llm, writer_prompt, name='writer')
        self.critic = AsyncAgent(
            llm, critic_prompt, name='critic', hooks=[PrefixedMessageHook(critic_prefix, ['writer'])]
        )
        self.max_turn = max_turn

    async def forward(self, message: AgentMessage, session_id=0) -> AgentMessage:
        for _ in range(self.max_turn):
            message = await self.writer(message, session_id=session_id)
            message = await self.critic(message, session_id=session_id)
        return await self.writer(message, session_id=session_id)

blogger = AsyncBlogger(
    'gpt-4o-2024-05-13',
    writer_prompt="You are an writing assistant tasked to write engaging blogpost. You try to generate the best blogpost possible for the user's request. "
    "If the user provides critique, then respond with a revised version of your previous attempts",
    critic_prompt="Generate critique and recommendations on the writing. Provide detailed recommendations, including requests for length, depth, style, etc..",
    critic_prefix='Reflect and provide critique on the following writing. \n\n',
)
user_prompt = (
    "Write an engaging blogpost on the recent updates in {topic}. "
    "The blogpost should be engaging and understandable for general audience. "
    "Should have more than 3 paragraphes but no longer than 1000 words.")
bot_msgs = asyncio.get_event_loop().run_until_complete(
    asyncio.gather(
        *[
            blogger(AgentMessage(sender='user', content=user_prompt.format(topic=topic)), session_id=i)
            for i, topic in enumerate(['AI', 'Biotechnology', 'New Energy', 'Video Games', 'Pop Music'])
        ]
    )
)
print(bot_msgs[0].content)
print('-' * 120)
for msg in blogger.state_dict(session_id=0)['writer.memory']:
    print('*' * 80)
    print(f'{msg["sender"]}:\n\n{msg["content"]}')
print('-' * 120)
for msg in blogger.state_dict(session_id=0)['critic.memory']:
    print('*' * 80)
    print(f'{msg["sender"]}:\n\n{msg["content"]}')
```

A multi-agent workflow that performs information retrieval, data collection and chart plotting ([original LangGraph example](https://vijaykumarkartha.medium.com/multiple-ai-agents-creating-multi-agent-workflows-using-langgraph-and-langchain-0587406ec4e6))

<div align="center">
<img src="https://miro.medium.com/v2/resize:fit:1400/format:webp/1*ffzadZCKXJT7n4JaRVFvcQ.jpeg" width="850" />
</div>

````python
import json
from lagent.actions import IPythonInterpreter, WebBrowser, ActionExecutor
from lagent.agents.stream import get_plugin_prompt
from lagent.llms import GPTAPI
from lagent.hooks import InternLMActionProcessor

TOOL_TEMPLATE = (
    "You are a helpful AI assistant, collaborating with other assistants. Use the provided tools to progress"
    " towards answering the question. If you are unable to fully answer, that's OK, another assistant with"
    " different tools will help where you left off. Execute what you can to make progress. If you or any of"
    " the other assistants have the final answer or deliverable, prefix your response with {finish_pattern}"
    " so the team knows to stop. You have access to the following tools:\n{tool_description}\nPlease provide"
    " your thought process when you need to use a tool, followed by the call statement in this format:"
    "\n{invocation_format}\\\\n**{system_prompt}**"
)

class DataVisualizer(Agent):
    def __init__(self, model_path, research_prompt, chart_prompt, finish_pattern="Final Answer", max_turn=10):
        super().__init__()
        llm = GPTAPI(model_path, key='YOUR_OPENAI_API_KEY', retry=5, max_new_tokens=1024, stop_words=["```\n"])
        interpreter, browser = IPythonInterpreter(), WebBrowser("BingSearch", api_key="YOUR_BING_API_KEY")
        self.researcher = Agent(
            llm,
            TOOL_TEMPLATE.format(
                finish_pattern=finish_pattern,
                tool_description=get_plugin_prompt(browser),
                invocation_format='```json\n{"name": {{tool name}}, "parameters": {{keyword arguments}}}\n```\n',
                system_prompt=research_prompt,
            ),
            output_format=ToolParser(
                "browser",
                begin="```json\n",
                end="\n```\n",
                validate=lambda x: json.loads(x.rstrip('`')),
            ),
            aggregator=InternLMToolAggregator(),
            name="researcher",
        )
        self.charter = Agent(
            llm,
            TOOL_TEMPLATE.format(
                finish_pattern=finish_pattern,
                tool_description=interpreter.name,
                invocation_format='```python\n{{code}}\n```\n',
                system_prompt=chart_prompt,
            ),
            output_format=ToolParser(
                "interpreter",
                begin="```python\n",
                end="\n```\n",
                validate=lambda x: x.rstrip('`'),
            ),
            aggregator=InternLMToolAggregator(),
            name="charter",
        )
        self.executor = ActionExecutor([interpreter, browser], hooks=[InternLMActionProcessor()])
        self.finish_pattern = finish_pattern
        self.max_turn = max_turn

    def forward(self, message, session_id=0):
        for _ in range(self.max_turn):
            message = self.researcher(message, session_id=session_id, stop_words=["```\n", "```python"])  # override llm stop words
            while message.formatted["tool_type"]:
                message = self.executor(message, session_id=session_id)
                message = self.researcher(message, session_id=session_id, stop_words=["```\n", "```python"])
            if self.finish_pattern in message.content:
                return message
            message = self.charter(message)
            while message.formatted["tool_type"]:
                message = self.executor(message, session_id=session_id)
                message = self.charter(message, session_id=session_id)
            if self.finish_pattern in message.content:
                return message
        return message

visualizer = DataVisualizer(
    "gpt-4o-2024-05-13",
    research_prompt="You should provide accurate data for the chart generator to use.",
    chart_prompt="Any charts you display will be visible by the user.",
)
user_msg = AgentMessage(
    sender='user',
    content="Fetch the China's GDP over the past 5 years, then draw a line graph of it. Once you code it up, finish.")
bot_msg = visualizer(user_msg)
print(bot_msg.content)
json.dump(visualizer.state_dict(), open('visualizer.json', 'w'), ensure_ascii=False, indent=4)
````

## Citation

If you find this project useful in your research, please consider cite:

```latex
@misc{lagent2023,
    title={{Lagent: InternLM} a lightweight open-source framework that allows users to efficiently build large language model(LLM)-based agents},
    author={Lagent Developer Team},
    howpublished = {\url{https://github.com/InternLM/lagent}},
    year={2023}
}
```

## License

This project is released under the [Apache 2.0 license](LICENSE).

<p align="right"><a href="#top">🔼 Back to top</a></p>
lagent.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,71 @@
LICENSE
MANIFEST.in
README.md
setup.cfg
setup.py
lagent/__init__.py
lagent/schema.py
lagent/version.py
lagent.egg-info/PKG-INFO
lagent.egg-info/SOURCES.txt
lagent.egg-info/dependency_links.txt
lagent.egg-info/requires.txt
lagent.egg-info/top_level.txt
lagent/actions/__init__.py
lagent/actions/action_executor.py
lagent/actions/arxiv_search.py
lagent/actions/base_action.py
lagent/actions/bing_map.py
lagent/actions/builtin_actions.py
lagent/actions/google_scholar_search.py
lagent/actions/google_search.py
lagent/actions/ipython_interactive.py
lagent/actions/ipython_interpreter.py
lagent/actions/ipython_manager.py
lagent/actions/parser.py
lagent/actions/ppt.py
lagent/actions/python_interpreter.py
lagent/actions/web_browser.py
lagent/agents/__init__.py
lagent/agents/agent.py
lagent/agents/react.py
lagent/agents/stream.py
lagent/agents/aggregator/__init__.py
lagent/agents/aggregator/default_aggregator.py
lagent/agents/aggregator/tool_aggregator.py
lagent/distributed/__init__.py
lagent/distributed/http_serve/__init__.py
lagent/distributed/http_serve/api_server.py
lagent/distributed/http_serve/app.py
lagent/distributed/ray_serve/__init__.py
lagent/distributed/ray_serve/ray_warpper.py
lagent/hooks/__init__.py
lagent/hooks/action_preprocessor.py
lagent/hooks/hook.py
lagent/hooks/logger.py
lagent/llms/__init__.py
lagent/llms/base_api.py
lagent/llms/base_llm.py
lagent/llms/huggingface.py
lagent/llms/lmdeploy_wrapper.py
lagent/llms/meta_template.py
lagent/llms/openai.py
lagent/llms/sensenova.py
lagent/llms/vllm_wrapper.py
lagent/memory/__init__.py
lagent/memory/base_memory.py
lagent/memory/manager.py
lagent/prompts/__init__.py
lagent/prompts/prompt_template.py
lagent/prompts/parsers/__init__.py
lagent/prompts/parsers/custom_parser.py
lagent/prompts/parsers/json_parser.py
lagent/prompts/parsers/str_parser.py
lagent/prompts/parsers/tool_parser.py
lagent/utils/__init__.py
lagent/utils/gen_key.py
lagent/utils/package.py
lagent/utils/util.py
requirements/docs.txt
requirements/optional.txt
requirements/runtime.txt
lagent.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
lagent.egg-info/requires.txt
ADDED
@@ -0,0 +1,59 @@
aiohttp
arxiv
asyncache
asyncer
distro
duckduckgo_search==5.3.1b1
filelock
func_timeout
griffe<1.0
json5
jsonschema
jupyter==1.0.0
jupyter_client==8.6.2
jupyter_core==5.7.2
pydantic==2.6.4
requests
termcolor
tiktoken
timeout-decorator
typing-extensions

[all]
google-search-results
lmdeploy>=0.2.5
pillow
python-pptx
timeout_decorator
torch
transformers<=4.40,>=4.34
vllm>=0.3.3
aiohttp
arxiv
asyncache
asyncer
distro
duckduckgo_search==5.3.1b1
filelock
func_timeout
griffe<1.0
json5
jsonschema
jupyter==1.0.0
jupyter_client==8.6.2
jupyter_core==5.7.2
pydantic==2.6.4
requests
termcolor
tiktoken
typing-extensions

[optional]
google-search-results
lmdeploy>=0.2.5
pillow
python-pptx
timeout_decorator
torch
transformers<=4.40,>=4.34
vllm>=0.3.3
lagent.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
lagent
lagent/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (231 Bytes)
lagent/__pycache__/schema.cpython-310.pyc
ADDED
Binary file (3.46 kB)
lagent/__pycache__/version.cpython-310.pyc
ADDED
Binary file (744 Bytes)
lagent/actions/__init__.py
CHANGED
@@ -1,6 +1,6 @@
 from .action_executor import ActionExecutor, AsyncActionExecutor
 from .arxiv_search import ArxivSearch, AsyncArxivSearch
-from .base_action import
+from .base_action import BaseAction, tool_api
 from .bing_map import AsyncBINGMap, BINGMap
 from .builtin_actions import FinishAction, InvalidAction, NoAction
 from .google_scholar_search import AsyncGoogleScholar, GoogleScholar
@@ -14,34 +14,23 @@ from .python_interpreter import AsyncPythonInterpreter, PythonInterpreter
 from .web_browser import AsyncWebBrowser, WebBrowser
 
 __all__ = [
-    'BaseAction',
-    '
-    '
-    '
-    '
-    '
-    '
-    '
-    'ArxivSearch',
-    'AsyncArxivSearch',
-    'GoogleSearch',
-    'AsyncGoogleSearch',
-    'GoogleScholar',
-    'AsyncGoogleScholar',
-    'IPythonInterpreter',
-    'AsyncIPythonInterpreter',
-    'IPythonInteractive',
-    'AsyncIPythonInteractive',
-    'IPythonInteractiveManager',
-    'PythonInterpreter',
-    'AsyncPythonInterpreter',
-    'PPT',
-    'AsyncPPT',
-    'WebBrowser',
-    'AsyncWebBrowser',
-    'BaseParser',
-    'JsonParser',
-    'TupleParser',
-    'tool_api',
-    'AsyncActionMixin',
+    'BaseAction', 'ActionExecutor', 'AsyncActionExecutor', 'InvalidAction',
+    'FinishAction', 'NoAction', 'BINGMap', 'AsyncBINGMap', 'ArxivSearch',
+    'AsyncArxivSearch', 'GoogleSearch', 'AsyncGoogleSearch', 'GoogleScholar',
+    'AsyncGoogleScholar', 'IPythonInterpreter', 'AsyncIPythonInterpreter',
+    'IPythonInteractive', 'AsyncIPythonInteractive',
+    'IPythonInteractiveManager', 'PythonInterpreter', 'AsyncPythonInterpreter',
+    'PPT', 'AsyncPPT', 'WebBrowser', 'AsyncWebBrowser', 'BaseParser',
+    'JsonParser', 'TupleParser', 'tool_api'
 ]
+from .weather_query import WeatherQuery
+__all__ = [
+    'BaseAction', 'ActionExecutor', 'AsyncActionExecutor', 'InvalidAction',
+    'FinishAction', 'NoAction', 'BINGMap', 'AsyncBINGMap', 'ArxivSearch',
+    'AsyncArxivSearch', 'GoogleSearch', 'AsyncGoogleSearch', 'GoogleScholar',
+    'AsyncGoogleScholar', 'IPythonInterpreter', 'AsyncIPythonInterpreter',
+    'IPythonInteractive', 'AsyncIPythonInteractive',
+    'IPythonInteractiveManager', 'PythonInterpreter', 'AsyncPythonInterpreter',
+    'PPT', 'AsyncPPT', 'WebBrowser', 'AsyncWebBrowser', 'BaseParser',
+    'JsonParser', 'TupleParser', 'tool_api', 'WeatherQuery'  # 这里
+]
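This change re-exports the new tool from `lagent.actions`, so downstream code such as the web demo above can import it directly. Note that the module now assigns `__all__` twice; the second list, which appends `'WeatherQuery'`, is the one that takes effect at import time. A quick sanity check, assuming the package from this commit is installed:

```python
# Confirm the weather tool is exposed at the package level after this change.
import lagent.actions as actions

print('WeatherQuery' in actions.__all__)           # expected: True
print(actions.WeatherQuery.__bases__[0].__name__)  # expected: 'BaseAction'
```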
lagent/actions/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.63 kB)
lagent/actions/__pycache__/action_executor.cpython-310.pyc
ADDED
Binary file (5.84 kB)
lagent/actions/__pycache__/arxiv_search.cpython-310.pyc
ADDED
Binary file (3.19 kB)
lagent/actions/__pycache__/base_action.cpython-310.pyc
ADDED
Binary file (11.6 kB)
lagent/actions/__pycache__/bing_map.cpython-310.pyc
ADDED
Binary file (7.79 kB)
lagent/actions/__pycache__/builtin_actions.cpython-310.pyc
ADDED
Binary file (3.89 kB)
lagent/actions/__pycache__/google_scholar_search.cpython-310.pyc
ADDED
Binary file (13 kB)
lagent/actions/__pycache__/google_search.cpython-310.pyc
ADDED
Binary file (6.93 kB)
lagent/actions/__pycache__/ipython_interactive.cpython-310.pyc
ADDED
Binary file (8.41 kB)
lagent/actions/__pycache__/ipython_interpreter.cpython-310.pyc
ADDED
Binary file (16.6 kB)
lagent/actions/__pycache__/ipython_manager.cpython-310.pyc
ADDED
Binary file (7.11 kB)
lagent/actions/__pycache__/parser.cpython-310.pyc
ADDED
Binary file (5.48 kB)
lagent/actions/__pycache__/ppt.cpython-310.pyc
ADDED
Binary file (6.81 kB)
lagent/actions/__pycache__/python_interpreter.cpython-310.pyc
ADDED
Binary file (5.38 kB)
lagent/actions/__pycache__/weather_query.cpython-310.pyc
ADDED
Binary file (2.66 kB)
lagent/actions/__pycache__/web_browser.cpython-310.pyc
ADDED
Binary file (28.8 kB)
lagent/actions/base_action.py
CHANGED
@@ -4,7 +4,7 @@ import re
|
|
4 |
from abc import ABCMeta
|
5 |
from copy import deepcopy
|
6 |
from functools import wraps
|
7 |
-
from typing import Callable,
|
8 |
|
9 |
try:
|
10 |
from typing import Annotated
|
@@ -24,15 +24,11 @@ from .parser import BaseParser, JsonParser, ParseError
|
|
24 |
logging.getLogger('griffe').setLevel(logging.ERROR)
|
25 |
|
26 |
|
27 |
-
def tool_api(
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
include_arguments: Optional[Iterable[str]] = None,
|
33 |
-
exclude_arguments: Optional[Iterable[str]] = None,
|
34 |
-
**kwargs,
|
35 |
-
):
|
36 |
"""Turn functions into tools. It will parse typehints as well as docstrings
|
37 |
to build the tool description and attach it to functions via an attribute
|
38 |
``api_description``.
|
@@ -94,16 +90,6 @@ def tool_api(
|
|
94 |
``return_data`` field will be added to ``api_description`` only
|
95 |
when ``explode_return`` or ``returns_named_value`` is enabled.
|
96 |
"""
|
97 |
-
if include_arguments is None:
|
98 |
-
exclude_arguments = exclude_arguments or set()
|
99 |
-
if isinstance(exclude_arguments, str):
|
100 |
-
exclude_arguments = {exclude_arguments}
|
101 |
-
elif not isinstance(exclude_arguments, set):
|
102 |
-
exclude_arguments = set(exclude_arguments)
|
103 |
-
if 'self' not in exclude_arguments:
|
104 |
-
exclude_arguments.add('self')
|
105 |
-
else:
|
106 |
-
include_arguments = {include_arguments} if isinstance(include_arguments, str) else set(include_arguments)
|
107 |
|
108 |
def _detect_type(string):
|
109 |
field_type = 'STRING'
|
@@ -120,9 +106,10 @@ def tool_api(
|
|
120 |
|
121 |
def _explode(desc):
|
122 |
kvs = []
|
123 |
-
desc = '\nArgs:\n' + '\n'.join(
|
124 |
-
|
125 |
-
|
|
|
126 |
docs = Docstring(desc).parse('google')
|
127 |
if not docs:
|
128 |
return kvs
|
@@ -138,12 +125,13 @@ def tool_api(
|
|
138 |
|
139 |
def _parse_tool(function):
|
140 |
# remove rst syntax
|
141 |
-
docs = Docstring(
|
142 |
-
'
|
143 |
-
|
144 |
desc = dict(
|
145 |
name=function.__name__,
|
146 |
-
description=docs[0].value
|
|
|
147 |
parameters=[],
|
148 |
required=[],
|
149 |
)
|
@@ -167,14 +155,17 @@ def tool_api(
|
|
167 |
|
168 |
sig = inspect.signature(function)
|
169 |
for name, param in sig.parameters.items():
|
170 |
-
if name
|
171 |
continue
|
172 |
parameter = dict(
|
173 |
-
name=param.name,
|
174 |
-
|
|
|
|
|
175 |
annotation = param.annotation
|
176 |
if annotation is inspect.Signature.empty:
|
177 |
-
parameter['type'] = args_doc.get(param.name,
|
|
|
178 |
else:
|
179 |
if get_origin(annotation) is Annotated:
|
180 |
annotation, info = get_args(annotation)
|
@@ -238,8 +229,9 @@ class ToolMeta(ABCMeta):
|
|
238 |
|
239 |
def __new__(mcs, name, base, attrs):
|
240 |
is_toolkit, tool_desc = True, dict(
|
241 |
-
name=name,
|
242 |
-
|
|
|
243 |
for key, value in attrs.items():
|
244 |
if callable(value) and hasattr(value, 'api_description'):
|
245 |
api_desc = getattr(value, 'api_description')
|
@@ -254,7 +246,8 @@ class ToolMeta(ABCMeta):
|
|
254 |
else:
|
255 |
tool_desc.setdefault('api_list', []).append(api_desc)
|
256 |
if not is_toolkit and 'api_list' in tool_desc:
|
257 |
-
raise KeyError('`run` and other tool APIs can not be implemented '
|
|
|
258 |
if is_toolkit and 'api_list' not in tool_desc:
|
259 |
is_toolkit = False
|
260 |
if callable(attrs.get('run')):
|
@@ -353,16 +346,26 @@ class BaseAction(metaclass=ToolMeta):
|
|
353 |
fallback_args = {'inputs': inputs, 'name': name}
|
354 |
if not hasattr(self, name):
|
355 |
return ActionReturn(
|
356 |
-
fallback_args,
|
357 |
-
|
|
|
|
|
358 |
try:
|
359 |
inputs = self._parser.parse_inputs(inputs, name)
|
360 |
except ParseError as exc:
|
361 |
-
return ActionReturn(
|
|
|
|
|
|
|
|
|
362 |
try:
|
363 |
outputs = getattr(self, name)(**inputs)
|
364 |
except Exception as exc:
|
365 |
-
return ActionReturn(
|
|
|
|
|
|
|
|
|
366 |
if isinstance(outputs, ActionReturn):
|
367 |
action_return = outputs
|
368 |
if not action_return.args:
|
@@ -399,16 +402,26 @@ class AsyncActionMixin:
|
|
399 |
fallback_args = {'inputs': inputs, 'name': name}
|
400 |
if not hasattr(self, name):
|
401 |
return ActionReturn(
|
402 |
-
fallback_args,
|
403 |
-
|
|
|
|
|
404 |
try:
|
405 |
  4       from abc import ABCMeta
  5       from copy import deepcopy
  6       from functools import wraps
  7   +   from typing import Callable, Optional, Type, get_args, get_origin
  8
  9       try:
 10           from typing import Annotated

 24       logging.getLogger('griffe').setLevel(logging.ERROR)
 25
 26
 27  +    def tool_api(func: Optional[Callable] = None,
 28  +                 *,
 29  +                 explode_return: bool = False,
 30  +                 returns_named_value: bool = False,
 31  +                 **kwargs):
 32           """Turn functions into tools. It will parse typehints as well as docstrings
 33           to build the tool description and attach it to functions via an attribute
 34           ``api_description``.

 90           ``return_data`` field will be added to ``api_description`` only
 91           when ``explode_return`` or ``returns_named_value`` is enabled.
 92           """
 93
 94       def _detect_type(string):
 95           field_type = 'STRING'

106
107       def _explode(desc):
108           kvs = []
109  +        desc = '\nArgs:\n' + '\n'.join([
110  +            '    ' + item.lstrip(' -+*#.')
111  +            for item in desc.split('\n')[1:] if item.strip()
112  +        ])
113           docs = Docstring(desc).parse('google')
114           if not docs:
115               return kvs

125
126       def _parse_tool(function):
127           # remove rst syntax
128  +        docs = Docstring(
129  +            re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse(
130  +                'google', returns_named_value=returns_named_value, **kwargs)
131           desc = dict(
132               name=function.__name__,
133  +            description=docs[0].value
134  +            if docs[0].kind is DocstringSectionKind.text else '',
135               parameters=[],
136               required=[],
137           )

155
156           sig = inspect.signature(function)
157           for name, param in sig.parameters.items():
158  +            if name == 'self':
159                   continue
160               parameter = dict(
161  +                name=param.name,
162  +                type='STRING',
163  +                description=args_doc.get(param.name,
164  +                                         {}).get('description', ''))
165               annotation = param.annotation
166               if annotation is inspect.Signature.empty:
167  +                parameter['type'] = args_doc.get(param.name,
168  +                                                 {}).get('type', 'STRING')
169               else:
170                   if get_origin(annotation) is Annotated:
171                       annotation, info = get_args(annotation)

229
230       def __new__(mcs, name, base, attrs):
231           is_toolkit, tool_desc = True, dict(
232  +            name=name,
233  +            description=Docstring(attrs.get('__doc__',
234  +                                            '')).parse('google')[0].value)
235           for key, value in attrs.items():
236               if callable(value) and hasattr(value, 'api_description'):
237                   api_desc = getattr(value, 'api_description')

246               else:
247                   tool_desc.setdefault('api_list', []).append(api_desc)
248           if not is_toolkit and 'api_list' in tool_desc:
249  +            raise KeyError('`run` and other tool APIs can not be implemented '
250  +                           'at the same time')
251           if is_toolkit and 'api_list' not in tool_desc:
252               is_toolkit = False
253           if callable(attrs.get('run')):

346           fallback_args = {'inputs': inputs, 'name': name}
347           if not hasattr(self, name):
348               return ActionReturn(
349  +                fallback_args,
350  +                type=self.name,
351  +                errmsg=f'invalid API: {name}',
352  +                state=ActionStatusCode.API_ERROR)
353           try:
354               inputs = self._parser.parse_inputs(inputs, name)
355           except ParseError as exc:
356  +            return ActionReturn(
357  +                fallback_args,
358  +                type=self.name,
359  +                errmsg=exc.err_msg,
360  +                state=ActionStatusCode.ARGS_ERROR)
361           try:
362               outputs = getattr(self, name)(**inputs)
363           except Exception as exc:
364  +            return ActionReturn(
365  +                inputs,
366  +                type=self.name,
367  +                errmsg=str(exc),
368  +                state=ActionStatusCode.API_ERROR)
369           if isinstance(outputs, ActionReturn):
370               action_return = outputs
371               if not action_return.args:

402           fallback_args = {'inputs': inputs, 'name': name}
403           if not hasattr(self, name):
404               return ActionReturn(
405  +                fallback_args,
406  +                type=self.name,
407  +                errmsg=f'invalid API: {name}',
408  +                state=ActionStatusCode.API_ERROR)
409           try:
410               inputs = self._parser.parse_inputs(inputs, name)
411           except ParseError as exc:
     -            return ActionReturn(
412  +            return ActionReturn(
413  +                fallback_args,
414  +                type=self.name,
415  +                errmsg=exc.err_msg,
416  +                state=ActionStatusCode.ARGS_ERROR)
417           try:
418               outputs = await getattr(self, name)(**inputs)
419           except Exception as exc:
     -            return ActionReturn(
420  +            return ActionReturn(
421  +                inputs,
422  +                type=self.name,
423  +                errmsg=str(exc),
424  +                state=ActionStatusCode.API_ERROR)
425           if isinstance(outputs, ActionReturn):
426               action_return = outputs
427               if not action_return.args:
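For reference, a minimal sketch of how the reworked decorator and dispatcher above are typically used. The `Calculator` class below is a hypothetical example written for this note, not part of the commit; the call convention follows the `(inputs, name)` signature shown in the hunks above, and errors come back as `ActionReturn` objects with `ARGS_ERROR` / `API_ERROR` states.

from lagent.actions.base_action import BaseAction, tool_api
from lagent.schema import ActionStatusCode

class Calculator(BaseAction):
    """A toy calculator toolkit, for illustration only."""

    @tool_api(explode_return=True)
    def add(self, a: int, b: int) -> dict:
        """Add two integers.

        Args:
            a (int): first operand
            b (int): second operand

        Returns:
            dict: the result
                * sum (int): a + b
        """
        return {'sum': a + b}

calc = Calculator()
print(calc.description)               # api_list auto-generated from the docstring by tool_api
ret = calc({'a': 1, 'b': 2}, name='add')   # dispatched through BaseAction.__call__
print(ret.errmsg if ret.state != ActionStatusCode.SUCCESS else ret.result)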
lagent/actions/weather_query.py
ADDED
@@ -0,0 +1,71 @@
import os
import requests
from lagent.actions.base_action import BaseAction, tool_api
from lagent.schema import ActionReturn, ActionStatusCode

class WeatherQuery(BaseAction):
    def __init__(self):
        super().__init__()
        self.api_key = os.getenv("weather_token")
        print(self.api_key)
        if not self.api_key:
            raise EnvironmentError("Environment variable 'weather_token' not found. Please set your QWeather API key, e.g. export weather_token='xxx'")

    @tool_api
    def run(self, location: str) -> dict:
        """
        Query real-time weather information.

        Args:
            location (str): The place to query: a location name, a LocationID, or latitude/longitude coordinates (e.g. "101010100" or "116.41,39.92").

        Returns:
            dict: a dictionary of weather information
                * location: location name
                * weather: weather condition
                * temperature: current temperature
                * wind_direction: wind direction
                * wind_speed: wind speed (km/h)
                * humidity: relative humidity (%)
                * report_time: time the data was reported
        """
        try:
            # If location is not in coordinate form (e.g. "116.41,39.92"), call the GeoAPI to get a LocationID
            if not ("," in location and location.replace(",", "").replace(".", "").isdigit()):
                # Use the GeoAPI to look up the LocationID
                geo_url = f"https://geoapi.qweather.com/v2/city/lookup?location={location}&key={self.api_key}"
                geo_response = requests.get(geo_url)
                geo_data = geo_response.json()

                if geo_data.get("code") != "200" or not geo_data.get("location"):
                    raise Exception(f"GeoAPI returned error code {geo_data.get('code')} or the location was not found")

                location = geo_data["location"][0]["id"]

            # Build the weather query API request URL
            weather_url = f"https://devapi.qweather.com/v7/weather/now?location={location}&key={self.api_key}"
            response = requests.get(weather_url)
            data = response.json()

            # Check the API response code
            if data.get("code") != "200":
                raise Exception(f"Weather API returned error code {data.get('code')}")

            # Parse and organize the weather information
            weather_info = {
                "location": location,
                "weather": data["now"]["text"],
                "temperature": data["now"]["temp"] + "°C",
                "wind_direction": data["now"]["windDir"],
                "wind_speed": data["now"]["windSpeed"] + " km/h",
                "humidity": data["now"]["humidity"] + "%",
                "report_time": data["updateTime"]
            }

            return {"result": weather_info}

        except Exception as exc:
            return ActionReturn(
                errmsg=f"WeatherQuery exception: {exc}",
                state=ActionStatusCode.HTTP_ERROR
            )
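A minimal usage sketch for the new action. It assumes a valid QWeather key has already been exported as `weather_token`; the query string is a placeholder and the returned fields follow the docstring above.

import os
os.environ.setdefault('weather_token', 'your-qweather-key')   # assumption: the key is provided via the environment

from lagent.actions.weather_query import WeatherQuery

tool = WeatherQuery()
result = tool.run('Beijing')        # or dispatch through the action: tool({'location': 'Beijing'}, name='run')
print(result)                       # {'result': {'location': ..., 'weather': ..., 'temperature': ...}}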
lagent/actions/web_browser.py
CHANGED
@@ -18,6 +18,7 @@ import requests
@@ -34,11 +35,12 @@ class BaseSearch:
-            if all(domain not in url
-                    'title': title
@@ -48,17 +50,15 @@ class BaseSearch:
-    def __init__(
-            **kwargs,
-    ):
@@ -67,39 +67,40 @@ class DuckDuckGoSearch(BaseSearch):
-                response = self._call_ddgs(
-                warnings.warn(
-        raise Exception(
-        from duckduckgo_search import AsyncDDGS
-                response = await ddgs.
-                warnings.warn(
-        raise Exception(
-        from duckduckgo_search import DDGS
-                asyncio.to_thread(ddgs.text, query.strip("'"), max_results=10),
@@ -109,35 +110,34 @@ class DuckDuckGoSearch(BaseSearch):
-            response = loop.run_until_complete(
-    def _parse_response(self, response:
-                (item['href'], item['description']
-    def __init__(
-            **kwargs,
-    ):
@@ -151,9 +151,11 @@ class BingSearch(BaseSearch):
-                warnings.warn(
-        raise Exception(
@@ -163,15 +165,18 @@ class BingSearch(BaseSearch):
-                warnings.warn(
-        raise Exception(
-        response = requests.get(
@@ -181,25 +186,32 @@ class BingSearch(BaseSearch):
-        webpages = {
-        for item in response.get('rankingResponse',
-                        raw_results.append(
-                raw_results.append(
@@ -218,27 +230,24 @@ class BraveSearch(BaseSearch):
-        extra_snippets (bool): Allows retrieving up to 5 additional snippets, which are alternative excerpts from the
-            search results.
-    def __init__(
-            **kwargs,
-    ):
@@ -256,9 +265,11 @@ class BraveSearch(BaseSearch):
-                warnings.warn(
-        raise Exception(
@@ -268,9 +279,11 @@ class BraveSearch(BaseSearch):
-                warnings.warn(
-        raise Exception(
@@ -280,10 +293,17 @@ class BraveSearch(BaseSearch):
-            **{
-        headers = {
@@ -295,16 +315,22 @@ class BraveSearch(BaseSearch):
-            **{
-        headers = {'X-Subscription-Token': self.api_key or '', 'Accept': 'application/json'}
@@ -315,13 +341,15 @@ class BraveSearch(BaseSearch):
-            raw_results.append(
-                (
@@ -348,18 +376,16 @@ class GoogleSearch(BaseSearch):
-    def __init__(
-            **kwargs,
-    ):
@@ -374,9 +400,12 @@ class GoogleSearch(BaseSearch):
-                warnings.warn(
-        raise Exception(
@@ -386,19 +415,29 @@ class GoogleSearch(BaseSearch):
-                warnings.warn(
-        raise Exception(
-            **{
-        headers = {
@@ -407,16 +446,22 @@ class GoogleSearch(BaseSearch):
-            **{
-        headers = {'X-API-KEY': self.api_key or '', 'Content-Type': 'application/json'}
@@ -427,34 +472,33 @@ class GoogleSearch(BaseSearch):
-            raw_results.append(
-            raw_results.append(
-        attributes = '. '.join(
-            (
-        for result in response[self.result_key_for_type[self.search_type]][: self.topk]:
-                f'{attribute}: {value}'
-            (
-                result.get('title', ''),
-            )
-        )
@@ -485,27 +529,25 @@ class TencentSearch(BaseSearch):
-    def __init__(
-            ],
-    ):
@@ -527,9 +569,11 @@ class TencentSearch(BaseSearch):
-                warnings.warn(
-        raise Exception(
@@ -539,9 +583,11 @@ class TencentSearch(BaseSearch):
-                warnings.warn(
-        raise Exception(
@@ -571,47 +617,33 @@ class TencentSearch(BaseSearch):
-        hashed_request_payload = hashlib.sha256(
-            http_request_method
-            + '\n'
-            +
-            + '\n'
-            + canonical_querystring
-            + '\n'
-            + canonical_headers
-            + '\n'
-            + signed_headers
-            + '\n'
-            + hashed_request_payload
-        )
-        hashed_canonical_request = hashlib.sha256(
-        signature = hmac.new(secret_signing, string_to_sign.encode('utf-8'),
-            algorithm
-            + ' '
-            + '
-            + self.secret_id
-            + '/'
-            + credential_scope
-            + ', '
-            + 'SignedHeaders='
-            + signed_headers
-            + ', '
-            + 'Signature='
-            + signature
-        )
@@ -620,7 +652,7 @@ class TencentSearch(BaseSearch):
-            'X-TC-Version': self.version
@@ -638,14 +670,16 @@ class TencentSearch(BaseSearch):
-            async with session.post(
@@ -654,7 +688,8 @@ class TencentSearch(BaseSearch):
-            raw_results.append((display['url'], display['content']
@@ -680,8 +715,8 @@ class ContentFetcher:
@@ -692,24 +727,24 @@
-    """Wrapper around the Web Browser Tool.
@@ -724,7 +759,10 @@ class WebBrowser(BaseAction):
-            future_to_query = {
@@ -737,9 +775,13 @@ class WebBrowser(BaseAction):
-                        search_results[
-        self.search_results = {
@@ -756,8 +798,7 @@ class WebBrowser(BaseAction):
-                for select_id in select_ids
-                if select_id in self.search_results
@@ -767,8 +808,10 @@ class WebBrowser(BaseAction):
-                    self.search_results[select_id][
@@ -784,12 +827,13 @@ class WebBrowser(BaseAction):
-    """Wrapper around the Web Browser Tool.
@@ -812,9 +856,13 @@ class AsyncWebBrowser(AsyncActionMixin, WebBrowser):
-                        search_results[
-        self.search_results = {
@@ -831,7 +879,8 @@ class AsyncWebBrowser(AsyncActionMixin, WebBrowser):
-                task = asyncio.create_task(
@@ -842,8 +891,10 @@ class AsyncWebBrowser(AsyncActionMixin, WebBrowser):
-                    self.search_results[select_id][
 18       from asyncache import cached as acached
 19       from bs4 import BeautifulSoup
 20       from cachetools import TTLCache, cached
 21  +    from duckduckgo_search import DDGS, AsyncDDGS
 22
 23       from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
 24       from lagent.actions.parser import BaseParser, JsonParser

 35           filtered_results = {}
 36           count = 0
 37           for url, snippet, title in results:
 38  +            if all(domain not in url
 39  +                   for domain in self.black_list) and not url.endswith('.pdf'):
 40                   filtered_results[count] = {
 41                       'url': url,
 42                       'summ': json.dumps(snippet, ensure_ascii=False)[1:-1],
 43  +                    'title': title
 44                   }
 45                   count += 1
 46                   if count >= self.topk:

 50
 51       class DuckDuckGoSearch(BaseSearch):
 52
 53  +        def __init__(self,
 54  +                     topk: int = 3,
 55  +                     black_list: List[str] = [
 56  +                         'enoN',
 57  +                         'youtube.com',
 58  +                         'bilibili.com',
 59  +                         'researchgate.net',
 60  +                     ],
 61  +                     **kwargs):
 62               self.proxy = kwargs.get('proxy')
 63               self.timeout = kwargs.get('timeout', 30)
 64               super().__init__(topk, black_list)

 67           def search(self, query: str, max_retry: int = 3) -> dict:
 68               for attempt in range(max_retry):
 69                   try:
 70  +                    response = self._call_ddgs(
 71  +                        query, timeout=self.timeout, proxy=self.proxy)
 72                       return self._parse_response(response)
 73                   except Exception as e:
 74                       logging.exception(str(e))
 75  +                    warnings.warn(
 76  +                        f'Retry {attempt + 1}/{max_retry} due to error: {e}')
 77                       time.sleep(random.randint(2, 5))
 78  +            raise Exception(
 79  +                'Failed to get search results from DuckDuckGo after retries.')
 80
 81           @acached(cache=TTLCache(maxsize=100, ttl=600))
 82           async def asearch(self, query: str, max_retry: int = 3) -> dict:
 83               for attempt in range(max_retry):
 84                   try:
 85                       ddgs = AsyncDDGS(timeout=self.timeout, proxy=self.proxy)
 86  +                    response = await ddgs.atext(query.strip("'"), max_results=10)
 87                       return self._parse_response(response)
 88                   except Exception as e:
 89                       if isinstance(e, asyncio.TimeoutError):
 90                           logging.exception('Request to DDGS timed out.')
 91                       logging.exception(str(e))
 92  +                    warnings.warn(
 93  +                        f'Retry {attempt + 1}/{max_retry} due to error: {e}')
 94                       await asyncio.sleep(random.randint(2, 5))
 95  +            raise Exception(
 96  +                'Failed to get search results from DuckDuckGo after retries.')
 97
 98           async def _async_call_ddgs(self, query: str, **kwargs) -> dict:
 99               ddgs = DDGS(**kwargs)
100               try:
101                   response = await asyncio.wait_for(
102  +                    asyncio.to_thread(ddgs.text, query.strip("'"), max_results=10),
103  +                    timeout=self.timeout)
104                   return response
105               except asyncio.TimeoutError:
106                   logging.exception('Request to DDGS timed out.')

110               loop = asyncio.new_event_loop()
111               asyncio.set_event_loop(loop)
112               try:
113  +                response = loop.run_until_complete(
114  +                    self._async_call_ddgs(query, **kwargs))
115                   return response
116               finally:
117                   loop.close()
118
119  +        def _parse_response(self, response: dict) -> dict:
120               raw_results = []
121               for item in response:
122                   raw_results.append(
123  +                    (item['href'], item['description']
124  +                     if 'description' in item else item['body'], item['title']))
125               return self._filter_results(raw_results)
126
127
class BingSearch(BaseSearch):
|
129 |
|
130 |
+
def __init__(self,
|
131 |
+
api_key: str,
|
132 |
+
region: str = 'zh-CN',
|
133 |
+
topk: int = 3,
|
134 |
+
black_list: List[str] = [
|
135 |
+
'enoN',
|
136 |
+
'youtube.com',
|
137 |
+
'bilibili.com',
|
138 |
+
'researchgate.net',
|
139 |
+
],
|
140 |
+
**kwargs):
|
|
|
|
|
141 |
self.api_key = api_key
|
142 |
self.market = region
|
143 |
self.proxy = kwargs.get('proxy')
|
|
|
151 |
return self._parse_response(response)
|
152 |
except Exception as e:
|
153 |
logging.exception(str(e))
|
154 |
+
warnings.warn(
|
155 |
+
f'Retry {attempt + 1}/{max_retry} due to error: {e}')
|
156 |
time.sleep(random.randint(2, 5))
|
157 |
+
raise Exception(
|
158 |
+
'Failed to get search results from Bing Search after retries.')
|
159 |
|
160 |
@acached(cache=TTLCache(maxsize=100, ttl=600))
|
161 |
async def asearch(self, query: str, max_retry: int = 3) -> dict:
|
|
|
165 |
return self._parse_response(response)
|
166 |
except Exception as e:
|
167 |
logging.exception(str(e))
|
168 |
+
warnings.warn(
|
169 |
+
f'Retry {attempt + 1}/{max_retry} due to error: {e}')
|
170 |
await asyncio.sleep(random.randint(2, 5))
|
171 |
+
raise Exception(
|
172 |
+
'Failed to get search results from Bing Search after retries.')
|
173 |
|
174 |
def _call_bing_api(self, query: str) -> dict:
|
175 |
endpoint = 'https://api.bing.microsoft.com/v7.0/search'
|
176 |
params = {'q': query, 'mkt': self.market, 'count': f'{self.topk * 2}'}
|
177 |
headers = {'Ocp-Apim-Subscription-Key': self.api_key}
|
178 |
+
response = requests.get(
|
179 |
+
endpoint, headers=headers, params=params, proxies=self.proxy)
|
180 |
response.raise_for_status()
|
181 |
return response.json()
|
182 |
|
|
|
186 |
headers = {'Ocp-Apim-Subscription-Key': self.api_key}
|
187 |
async with aiohttp.ClientSession(raise_for_status=True) as session:
|
188 |
async with session.get(
|
189 |
+
endpoint,
|
190 |
+
headers=headers,
|
191 |
+
params=params,
|
192 |
+
proxy=self.proxy and
|
193 |
+
(self.proxy.get('http') or self.proxy.get('https'))) as resp:
|
194 |
return await resp.json()
|
195 |
|
196 |
def _parse_response(self, response: dict) -> dict:
|
197 |
+
webpages = {
|
198 |
+
w['id']: w
|
199 |
+
for w in response.get('webPages', {}).get('value', [])
|
200 |
+
}
|
201 |
raw_results = []
|
202 |
|
203 |
+
for item in response.get('rankingResponse',
|
204 |
+
{}).get('mainline', {}).get('items', []):
|
205 |
if item['answerType'] == 'WebPages':
|
206 |
webpage = webpages.get(item['value']['id'])
|
207 |
if webpage:
|
208 |
+
raw_results.append(
|
209 |
+
(webpage['url'], webpage['snippet'], webpage['name']))
|
210 |
+
elif item['answerType'] == 'News' and item['value'][
|
211 |
+
'id'] == response.get('news', {}).get('id'):
|
212 |
for news in response.get('news', {}).get('value', []):
|
213 |
+
raw_results.append(
|
214 |
+
(news['url'], news['description'], news['name']))
|
215 |
|
216 |
return self._filter_results(raw_results)
|
217 |
|
|
|
230 |
topk (int): The number of search results returned in response from API search results.
|
231 |
region (str): The country code string. Specifies the country where the search results come from.
|
232 |
language (str): The language code string. Specifies the preferred language for the search results.
|
233 |
+
extra_snippets (bool): Allows retrieving up to 5 additional snippets, which are alternative excerpts from the search results.
|
|
|
234 |
**kwargs: Any other parameters related to the Brave Search API. Find more details at
|
235 |
https://api.search.brave.com/app/documentation/web-search/get-started.
|
236 |
"""
|
237 |
|
238 |
+
def __init__(self,
|
239 |
+
api_key: str,
|
240 |
+
region: str = 'ALL',
|
241 |
+
language: str = 'zh-hans',
|
242 |
+
extra_snippests: bool = True,
|
243 |
+
topk: int = 3,
|
244 |
+
black_list: List[str] = [
|
245 |
+
'enoN',
|
246 |
+
'youtube.com',
|
247 |
+
'bilibili.com',
|
248 |
+
'researchgate.net',
|
249 |
+
],
|
250 |
+
**kwargs):
|
|
|
|
|
251 |
self.api_key = api_key
|
252 |
self.market = region
|
253 |
self.proxy = kwargs.get('proxy')
|
|
|
265 |
return self._parse_response(response)
|
266 |
except Exception as e:
|
267 |
logging.exception(str(e))
|
268 |
+
warnings.warn(
|
269 |
+
f'Retry {attempt + 1}/{max_retry} due to error: {e}')
|
270 |
time.sleep(random.randint(2, 5))
|
271 |
+
raise Exception(
|
272 |
+
'Failed to get search results from Brave Search after retries.')
|
273 |
|
274 |
@acached(cache=TTLCache(maxsize=100, ttl=600))
|
275 |
async def asearch(self, query: str, max_retry: int = 3) -> dict:
|
|
|
279 |
return self._parse_response(response)
|
280 |
except Exception as e:
|
281 |
logging.exception(str(e))
|
282 |
+
warnings.warn(
|
283 |
+
f'Retry {attempt + 1}/{max_retry} due to error: {e}')
|
284 |
await asyncio.sleep(random.randint(2, 5))
|
285 |
+
raise Exception(
|
286 |
+
'Failed to get search results from Brave Search after retries.')
|
287 |
|
288 |
def _call_brave_api(self, query: str) -> dict:
|
289 |
endpoint = f'https://api.search.brave.com/res/v1/{self.search_type}/search'
|
|
|
293 |
'search_lang': self.language,
|
294 |
'extra_snippets': self.extra_snippests,
|
295 |
'count': self.topk,
|
296 |
+
**{
|
297 |
+
key: value
|
298 |
+
for key, value in self.kwargs.items() if value is not None
|
299 |
+
},
|
300 |
}
|
301 |
+
headers = {
|
302 |
+
'X-Subscription-Token': self.api_key or '',
|
303 |
+
'Accept': 'application/json'
|
304 |
+
}
|
305 |
+
response = requests.get(
|
306 |
+
endpoint, headers=headers, params=params, proxies=self.proxy)
|
307 |
response.raise_for_status()
|
308 |
return response.json()
|
309 |
|
|
|
315 |
'search_lang': self.language,
|
316 |
'extra_snippets': self.extra_snippests,
|
317 |
'count': self.topk,
|
318 |
+
**{
|
319 |
+
key: value
|
320 |
+
for key, value in self.kwargs.items() if value is not None
|
321 |
+
},
|
322 |
+
}
|
323 |
+
headers = {
|
324 |
+
'X-Subscription-Token': self.api_key or '',
|
325 |
+
'Accept': 'application/json'
|
326 |
}
|
|
|
327 |
async with aiohttp.ClientSession(raise_for_status=True) as session:
|
328 |
async with session.get(
|
329 |
+
endpoint,
|
330 |
+
headers=headers,
|
331 |
+
params=params,
|
332 |
+
proxy=self.proxy and
|
333 |
+
(self.proxy.get('http') or self.proxy.get('https'))) as resp:
|
334 |
return await resp.json()
|
335 |
|
336 |
def _parse_response(self, response: dict) -> dict:
|
|
|
341 |
raw_results = []
|
342 |
|
343 |
for item in filtered_result:
|
344 |
+
raw_results.append((
|
345 |
+
item.get('url', ''),
|
346 |
+
' '.join(
|
347 |
+
filter(None, [
|
348 |
+
item.get('description'),
|
349 |
+
*item.get('extra_snippets', [])
|
350 |
+
])),
|
351 |
+
item.get('title', ''),
|
352 |
+
))
|
353 |
return self._filter_results(raw_results)
|
354 |
|
355 |
|
|
|
376 |
'search': 'organic',
|
377 |
}
|
378 |
|
379 |
+
def __init__(self,
|
380 |
+
api_key: str,
|
381 |
+
topk: int = 3,
|
382 |
+
black_list: List[str] = [
|
383 |
+
'enoN',
|
384 |
+
'youtube.com',
|
385 |
+
'bilibili.com',
|
386 |
+
'researchgate.net',
|
387 |
+
],
|
388 |
+
**kwargs):
|
|
|
|
|
389 |
self.api_key = api_key
|
390 |
self.proxy = kwargs.get('proxy')
|
391 |
self.search_type = kwargs.get('search_type', 'search')
|
|
|
400 |
return self._parse_response(response)
|
401 |
except Exception as e:
|
402 |
logging.exception(str(e))
|
403 |
+
warnings.warn(
|
404 |
+
f'Retry {attempt + 1}/{max_retry} due to error: {e}')
|
405 |
time.sleep(random.randint(2, 5))
|
406 |
+
raise Exception(
|
407 |
+
'Failed to get search results from Google Serper Search after retries.'
|
408 |
+
)
|
409 |
|
410 |
@acached(cache=TTLCache(maxsize=100, ttl=600))
|
411 |
async def asearch(self, query: str, max_retry: int = 3) -> dict:
|
|
|
415 |
return self._parse_response(response)
|
416 |
except Exception as e:
|
417 |
logging.exception(str(e))
|
418 |
+
warnings.warn(
|
419 |
+
f'Retry {attempt + 1}/{max_retry} due to error: {e}')
|
420 |
await asyncio.sleep(random.randint(2, 5))
|
421 |
+
raise Exception(
|
422 |
+
'Failed to get search results from Google Serper Search after retries.'
|
423 |
+
)
|
424 |
|
425 |
def _call_serper_api(self, query: str) -> dict:
|
426 |
endpoint = f'https://google.serper.dev/{self.search_type}'
|
427 |
params = {
|
428 |
'q': query,
|
429 |
'num': self.topk,
|
430 |
+
**{
|
431 |
+
key: value
|
432 |
+
for key, value in self.kwargs.items() if value is not None
|
433 |
+
},
|
434 |
}
|
435 |
+
headers = {
|
436 |
+
'X-API-KEY': self.api_key or '',
|
437 |
+
'Content-Type': 'application/json'
|
438 |
+
}
|
439 |
+
response = requests.get(
|
440 |
+
endpoint, headers=headers, params=params, proxies=self.proxy)
|
441 |
response.raise_for_status()
|
442 |
return response.json()
|
443 |
|
|
|
446 |
params = {
|
447 |
'q': query,
|
448 |
'num': self.topk,
|
449 |
+
**{
|
450 |
+
key: value
|
451 |
+
for key, value in self.kwargs.items() if value is not None
|
452 |
+
},
|
453 |
+
}
|
454 |
+
headers = {
|
455 |
+
'X-API-KEY': self.api_key or '',
|
456 |
+
'Content-Type': 'application/json'
|
457 |
}
|
|
|
458 |
async with aiohttp.ClientSession(raise_for_status=True) as session:
|
459 |
async with session.get(
|
460 |
+
endpoint,
|
461 |
+
headers=headers,
|
462 |
+
params=params,
|
463 |
+
proxy=self.proxy and
|
464 |
+
(self.proxy.get('http') or self.proxy.get('https'))) as resp:
|
465 |
return await resp.json()
|
466 |
|
467 |
def _parse_response(self, response: dict) -> dict:
|
|
|
472 |
if answer_box.get('answer'):
|
473 |
raw_results.append(('', answer_box.get('answer'), ''))
|
474 |
elif answer_box.get('snippet'):
|
475 |
+
raw_results.append(
|
476 |
+
('', answer_box.get('snippet').replace('\n', ' '), ''))
|
477 |
elif answer_box.get('snippetHighlighted'):
|
478 |
+
raw_results.append(
|
479 |
+
('', answer_box.get('snippetHighlighted'), ''))
|
480 |
|
481 |
if response.get('knowledgeGraph'):
|
482 |
kg = response.get('knowledgeGraph', {})
|
483 |
description = kg.get('description', '')
|
484 |
+
attributes = '. '.join(
|
485 |
+
f'{attribute}: {value}'
|
486 |
+
for attribute, value in kg.get('attributes', {}).items())
|
487 |
raw_results.append(
|
488 |
+
(kg.get('descriptionLink', ''),
|
489 |
+
f'{description}. {attributes}' if attributes else description,
|
490 |
+
f"{kg.get('title', '')}: {kg.get('type', '')}."))
|
491 |
+
|
492 |
+
for result in response[self.result_key_for_type[
|
493 |
+
self.search_type]][:self.topk]:
|
|
|
|
|
494 |
description = result.get('snippet', '')
|
495 |
attributes = '. '.join(
|
496 |
+
f'{attribute}: {value}'
|
497 |
+
for attribute, value in result.get('attributes', {}).items())
|
498 |
raw_results.append(
|
499 |
+
(result.get('link', ''),
|
500 |
+
f'{description}. {attributes}' if attributes else description,
|
501 |
+
result.get('title', '')))
|
|
|
|
|
|
|
502 |
|
503 |
return self._filter_results(raw_results)
|
504 |
|
|
|
529 |
Supports multiple values separated by commas. Example: `30010255`.
|
530 |
"""
|
531 |
|
532 |
+
def __init__(self,
|
533 |
+
secret_id: str = 'Your SecretId',
|
534 |
+
secret_key: str = 'Your SecretKey',
|
535 |
+
api_key: str = '',
|
536 |
+
action: str = 'SearchCommon',
|
537 |
+
version: str = '2020-12-29',
|
538 |
+
service: str = 'tms',
|
539 |
+
host: str = 'tms.tencentcloudapi.com',
|
540 |
+
topk: int = 3,
|
541 |
+
tsn: int = None,
|
542 |
+
insite: str = None,
|
543 |
+
category: str = None,
|
544 |
+
vrid: str = None,
|
545 |
+
black_list: List[str] = [
|
546 |
+
'enoN',
|
547 |
+
'youtube.com',
|
548 |
+
'bilibili.com',
|
549 |
+
'researchgate.net',
|
550 |
+
]):
|
|
|
|
|
551 |
self.secret_id = secret_id
|
552 |
self.secret_key = secret_key
|
553 |
self.api_key = api_key
|
|
|
569 |
return self._parse_response(response)
|
570 |
except Exception as e:
|
571 |
logging.exception(str(e))
|
572 |
+
warnings.warn(
|
573 |
+
f'Retry {attempt + 1}/{max_retry} due to error: {e}')
|
574 |
time.sleep(random.randint(2, 5))
|
575 |
+
raise Exception(
|
576 |
+
'Failed to get search results from Bing Search after retries.')
|
577 |
|
578 |
@acached(cache=TTLCache(maxsize=100, ttl=600))
|
579 |
async def asearch(self, query: str, max_retry: int = 3) -> dict:
|
|
|
583 |
return self._parse_response(response)
|
584 |
except Exception as e:
|
585 |
logging.exception(str(e))
|
586 |
+
warnings.warn(
|
587 |
+
f'Retry {attempt + 1}/{max_retry} due to error: {e}')
|
588 |
await asyncio.sleep(random.randint(2, 5))
|
589 |
+
raise Exception(
|
590 |
+
'Failed to get search results from Bing Search after retries.')
|
591 |
|
592 |
def _get_headers_and_payload(self, query: str) -> tuple:
|
593 |
|
|
|
617 |
ct = 'application/json; charset=utf-8'
|
618 |
canonical_headers = f'content-type:{ct}\nhost:{self.host}\nx-tc-action:{self.action.lower()}\n'
|
619 |
signed_headers = 'content-type;host;x-tc-action'
|
620 |
+
hashed_request_payload = hashlib.sha256(
|
621 |
+
payload.encode('utf-8')).hexdigest()
|
622 |
canonical_request = (
|
623 |
+
http_request_method + '\n' + canonical_uri + '\n' +
|
624 |
+
canonical_querystring + '\n' + canonical_headers + '\n' +
|
625 |
+
signed_headers + '\n' + hashed_request_payload)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
626 |
|
627 |
# ************* 步骤 2:拼接待签名字符串 *************
|
628 |
credential_scope = date + '/' + self.service + '/' + 'tc3_request'
|
629 |
+
hashed_canonical_request = hashlib.sha256(
|
630 |
+
canonical_request.encode('utf-8')).hexdigest()
|
631 |
+
string_to_sign = (
|
632 |
+
algorithm + '\n' + str(timestamp) + '\n' + credential_scope +
|
633 |
+
'\n' + hashed_canonical_request)
|
634 |
|
635 |
# ************* 步骤 3:计算签名 *************
|
636 |
secret_date = sign(('TC3' + self.secret_key).encode('utf-8'), date)
|
637 |
secret_service = sign(secret_date, self.service)
|
638 |
secret_signing = sign(secret_service, 'tc3_request')
|
639 |
+
signature = hmac.new(secret_signing, string_to_sign.encode('utf-8'),
|
640 |
+
hashlib.sha256).hexdigest()
|
641 |
|
642 |
# ************* 步骤 4:拼接 Authorization *************
|
643 |
authorization = (
|
644 |
+
algorithm + ' ' + 'Credential=' + self.secret_id + '/' +
|
645 |
+
credential_scope + ', ' + 'SignedHeaders=' + signed_headers +
|
646 |
+
', ' + 'Signature=' + signature)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
647 |
|
648 |
# ************* 步骤 5:构造并发起请求 *************
|
649 |
headers = {
|
|
|
652 |
'Host': self.host,
|
653 |
'X-TC-Action': self.action,
|
654 |
'X-TC-Timestamp': str(timestamp),
|
655 |
+
'X-TC-Version': self.version
|
656 |
}
|
657 |
# if self.region:
|
658 |
# headers["X-TC-Region"] = self.region
|
|
|
670 |
except Exception as e:
|
671 |
logging.warning(str(e))
|
672 |
import ast
|
|
|
673 |
resp = ast.literal_eval(resp)
|
674 |
return resp.get('Response', dict())
|
675 |
|
676 |
async def _async_call_tencent_api(self, query: str):
|
677 |
headers, payload = self._get_headers_and_payload(query)
|
678 |
async with aiohttp.ClientSession(raise_for_status=True) as session:
|
679 |
+
async with session.post(
|
680 |
+
'https://' + self.host.lstrip('/'),
|
681 |
+
headers=headers,
|
682 |
+
data=payload) as resp:
|
683 |
return (await resp.json()).get('Response', {})
|
684 |
|
685 |
def _parse_response(self, response: dict) -> dict:
|
|
|
688 |
display = json.loads(item['Display'])
|
689 |
if not display['url']:
|
690 |
continue
|
691 |
+
raw_results.append((display['url'], display['content']
|
692 |
+
or display['abstract_info'], display['title']))
|
693 |
return self._filter_results(raw_results)
|
694 |
|
695 |
|
|
|
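The TencentSearch hunks above only compress the existing TC3-HMAC-SHA256 signing into fewer lines; the standalone sketch below restates the same flow outside the class for readability. Endpoint, action and payload values are placeholders, and `sign` mirrors the helper used in the file; the same `timestamp` must also be sent in the `X-TC-Timestamp` header.

import hashlib, hmac, json, time
from datetime import datetime, timezone

def sign(key: bytes, msg: str) -> bytes:
    return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()

def tc3_authorization(secret_id: str, secret_key: str, host: str, service: str, payload: str):
    timestamp = int(time.time())
    date = datetime.fromtimestamp(timestamp, timezone.utc).strftime('%Y-%m-%d')
    algorithm = 'TC3-HMAC-SHA256'
    # Step 1: canonical request ('searchcommon' is the lowercased action, a placeholder here)
    ct = 'application/json; charset=utf-8'
    canonical_headers = f'content-type:{ct}\nhost:{host}\nx-tc-action:searchcommon\n'
    signed_headers = 'content-type;host;x-tc-action'
    hashed_payload = hashlib.sha256(payload.encode('utf-8')).hexdigest()
    canonical_request = 'POST\n/\n\n' + canonical_headers + '\n' + signed_headers + '\n' + hashed_payload
    # Step 2: string to sign
    credential_scope = f'{date}/{service}/tc3_request'
    string_to_sign = (algorithm + '\n' + str(timestamp) + '\n' + credential_scope + '\n'
                      + hashlib.sha256(canonical_request.encode('utf-8')).hexdigest())
    # Step 3: derive the signing key and sign
    secret_signing = sign(sign(sign(('TC3' + secret_key).encode('utf-8'), date), service), 'tc3_request')
    signature = hmac.new(secret_signing, string_to_sign.encode('utf-8'), hashlib.sha256).hexdigest()
    # Step 4: assemble the Authorization header
    authorization = (f'{algorithm} Credential={secret_id}/{credential_scope}, '
                     f'SignedHeaders={signed_headers}, Signature={signature}')
    return authorization, timestamp

# usage with placeholder credentials
auth, ts = tc3_authorization('AKID...', 'secret', 'tms.tencentcloudapi.com', 'tms', json.dumps({'Query': 'hello'}))
print(auth, ts)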
715           async def afetch(self, url: str) -> Tuple[bool, str]:
716               try:
717                   async with aiohttp.ClientSession(
718  +                        raise_for_status=True,
719  +                        timeout=aiohttp.ClientTimeout(self.timeout)) as session:
720                       async with session.get(url) as resp:
721                           html = await resp.text(errors='ignore')
722                           text = BeautifulSoup(html, 'html.parser').get_text()

727
728
729       class WebBrowser(BaseAction):
730  +        """Wrapper around the Web Browser Tool.
731  +        """
732  +
733  +        def __init__(self,
734  +                     searcher_type: str = 'DuckDuckGoSearch',
735  +                     timeout: int = 5,
736  +                     black_list: Optional[List[str]] = [
737  +                         'enoN',
738  +                         'youtube.com',
739  +                         'bilibili.com',
740  +                         'researchgate.net',
741  +                     ],
742  +                     topk: int = 20,
743  +                     description: Optional[dict] = None,
744  +                     parser: Type[BaseParser] = JsonParser,
745  +                     **kwargs):
746  +            self.searcher = eval(searcher_type)(
747  +                black_list=black_list, topk=topk, **kwargs)
748               self.fetcher = ContentFetcher(timeout=timeout)
749               self.search_results = None
750               super().__init__(description, parser)

759               search_results = {}
760
761               with ThreadPoolExecutor() as executor:
762  +                future_to_query = {
763  +                    executor.submit(self.searcher.search, q): q
764  +                    for q in queries
765  +                }
766
767                   for future in as_completed(future_to_query):
768                       query = future_to_query[future]

775                           if result['url'] not in search_results:
776                               search_results[result['url']] = result
777                           else:
778  +                            search_results[
779  +                                result['url']]['summ'] += f"\n{result['summ']}"
780
781  +            self.search_results = {
782  +                idx: result
783  +                for idx, result in enumerate(search_results.values())
784  +            }
785               return self.search_results
786
787           @tool_api

798               with ThreadPoolExecutor() as executor:
799                   future_to_id = {
800                       executor.submit(self.fetcher.fetch, self.search_results[select_id]['url']): select_id
801  +                    for select_id in select_ids if select_id in self.search_results
802                   }
803                   for future in as_completed(future_to_id):
804                       select_id = future_to_id[future]

808                           warnings.warn(f'{select_id} generated an exception: {exc}')
809                       else:
810                           if web_success:
811  +                            self.search_results[select_id][
812  +                                'content'] = web_content[:8192]
813  +                            new_search_results[select_id] = self.search_results[
814  +                                select_id].copy()
815                               new_search_results[select_id].pop('summ')
816
817               return new_search_results

827
828
829       class AsyncWebBrowser(AsyncActionMixin, WebBrowser):
830  +        """Wrapper around the Web Browser Tool.
831  +        """
832
833           @tool_api
834           async def search(self, query: Union[str, List[str]]) -> dict:
835               """BING search API
836  +
837               Args:
838                   query (List[str]): list of search query strings
839               """

856                       if result['url'] not in search_results:
857                           search_results[result['url']] = result
858                       else:
859  +                        search_results[
860  +                            result['url']]['summ'] += f"\n{result['summ']}"
861
862  +            self.search_results = {
863  +                idx: result
864  +                for idx, result in enumerate(search_results.values())
865  +            }
866               return self.search_results
867
868           @tool_api

879               tasks = []
880               for select_id in select_ids:
881                   if select_id in self.search_results:
882  +                    task = asyncio.create_task(
883  +                        self.fetcher.afetch(self.search_results[select_id]['url']))
884                       task.select_id = select_id
885                       tasks.append(task)
886               async for future in async_as_completed(tasks):

891                           warnings.warn(f'{select_id} generated an exception: {exc}')
892                       else:
893                           if web_success:
894  +                            self.search_results[select_id][
895  +                                'content'] = web_content[:8192]
896  +                            new_search_results[select_id] = self.search_results[
897  +                                select_id].copy()
898                               new_search_results[select_id].pop('summ')
899                   return new_search_results
900
lagent/agents/__init__.py
CHANGED
@@ -1,33 +1,9 @@
-from .agent import
-    Agent,
-    AgentDict,
-    AgentList,
-    AsyncAgent,
-    AsyncSequential,
-    AsyncStreamingAgent,
-    AsyncStreamingSequential,
-    Sequential,
-    StreamingAgent,
-    StreamingSequential,
-)
+from .agent import Agent, AgentDict, AgentList, AsyncAgent, AsyncSequential, Sequential
 from .react import AsyncReAct, ReAct
 from .stream import AgentForInternLM, AsyncAgentForInternLM, AsyncMathCoder, MathCoder

 __all__ = [
-    'Agent',
-    '
-    '
-    'AsyncAgent',
-    'AgentForInternLM',
-    'AsyncAgentForInternLM',
-    'MathCoder',
-    'AsyncMathCoder',
-    'ReAct',
-    'AsyncReAct',
-    'Sequential',
-    'AsyncSequential',
-    'StreamingAgent',
-    'StreamingSequential',
-    'AsyncStreamingAgent',
-    'AsyncStreamingSequential',
+    'Agent', 'AgentDict', 'AgentList', 'AsyncAgent', 'AgentForInternLM',
+    'AsyncAgentForInternLM', 'MathCoder', 'AsyncMathCoder', 'ReAct',
+    'AsyncReAct', 'Sequential', 'AsyncSequential'
 ]
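After this simplification, downstream code imports only the non-streaming entry points; a quick, purely illustrative check:

from lagent.agents import Agent, AsyncAgent, Sequential, ReAct, AgentForInternLM

print([cls.__name__ for cls in (Agent, AsyncAgent, Sequential, ReAct, AgentForInternLM)])
# Streaming variants (StreamingAgent, AsyncStreamingSequential, ...) are no longer exported by this package.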
lagent/agents/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (547 Bytes)
lagent/agents/__pycache__/agent.cpython-310.pyc
ADDED
Binary file (12.9 kB)
lagent/agents/__pycache__/react.cpython-310.pyc
ADDED
Binary file (4.85 kB)
lagent/agents/__pycache__/stream.cpython-310.pyc
ADDED
Binary file (8.95 kB)
lagent/agents/agent.py
CHANGED
@@ -3,7 +3,7 @@ import warnings
 from collections import OrderedDict, UserDict, UserList, abc
 from functools import wraps
 from itertools import chain, repeat
-from typing import Any,

 from lagent.agents.aggregator import DefaultAggregator
 from lagent.hooks import Hook, RemovableHandle
@@ -11,7 +11,7 @@ from lagent.llms import BaseLLM
 from lagent.memory import Memory, MemoryManager
 from lagent.prompts.parsers import StrParser
 from lagent.prompts.prompt_template import PromptTemplate
-from lagent.schema import AgentMessage
 from lagent.utils import create_object


@@ -63,17 +63,29 @@ class Agent:
         if self.memory:
             self.memory.add(message, session_id=session_id)

-    def __call__(
         # message.receiver = self.name
-        message = [
         for hook in self._hooks.values():
             result = hook.before_agent(self, message, session_id)
             if result:
                 message = result
         self.update_memory(message, session_id=session_id)
-        response_message = self.forward(
         if not isinstance(response_message, AgentMessage):
-            response_message = AgentMessage(
         self.update_memory(response_message, session_id=session_id)
         response_message = copy.deepcopy(response_message)
         for hook in self._hooks.values():
@@ -82,14 +94,25 @@ class Agent:
                 response_message = result
         return response_message

-    def forward(self,
         formatted_messages = self.aggregator.aggregate(
-            self.memory.get(session_id),
         )
         llm_response = self.llm.chat(formatted_messages, **kwargs)
         if self.output_format:
-            formatted_messages = self.output_format.parse_response(
         return llm_response

     def __setattr__(self, __name: str, __value: Any) -> None:
@@ -142,8 +165,12 @@ class Agent:
         self._hooks[handle.id] = hook
         return handle

-    def reset(self,
         if keypath:
             keys, agent = keypath.split('.'), self
             for key in keys:
@@ -162,13 +189,15 @@ class Agent:
     def __repr__(self):

         def _rcsv_repr(agent, n_indent=1):
-            res = agent.__class__.__name__ + (f"(name='{agent.name}')"
             modules = [
                 f"{n_indent * ' '}({name}): {_rcsv_repr(agent, n_indent + 1)}"
                 for name, agent in getattr(agent, '_agents', {}).items()
             ]
             if modules:
-                res += '(\n' + '\n'.join(
             elif not res.endswith(')'):
                 res += '()'
             return res
@@ -176,18 +205,28 @@ class Agent:
         return _rcsv_repr(self)


-class
-    async def __call__(self,
         for hook in self._hooks.values():
             result = hook.before_agent(self, message, session_id)
             if result:
                 message = result
         self.update_memory(message, session_id=session_id)
-        response_message = await self.forward(
         if not isinstance(response_message, AgentMessage):
-            response_message = AgentMessage(
         self.update_memory(response_message, session_id=session_id)
         response_message = copy.deepcopy(response_message)
         for hook in self._hooks.values():
@@ -196,133 +235,40 @@ class AsyncAgentMixin:
                 response_message = result
         return response_message

-    async def forward(self,
         formatted_messages = self.aggregator.aggregate(
-            self.memory.get(session_id),
         )
-        llm_response = await self.llm.chat(formatted_messages, session_id,
         if self.output_format:
-            formatted_messages = self.output_format.parse_response(
-    """Asynchronous variant of the Agent class"""
-
-    pass
-
-
-class StreamingAgentMixin:
-    """Component that makes agent calling output a streaming response."""
-
-    def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> Generator[AgentMessage, None, None]:
-        message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message]
-        for hook in self._hooks.values():
-            result = hook.before_agent(self, message, session_id)
-            if result:
-                message = result
-        self.update_memory(message, session_id=session_id)
-        response_message = AgentMessage(sender=self.name, content="")
-        for response_message in self.forward(*message, session_id=session_id, **kwargs):
-            if not isinstance(response_message, AgentMessage):
-                model_state, response = response_message
-                response_message = AgentMessage(sender=self.name, content=response, stream_state=model_state)
-            yield response_message.model_copy()
-        self.update_memory(response_message, session_id=session_id)
-        response_message = copy.deepcopy(response_message)
-        for hook in self._hooks.values():
-            result = hook.after_agent(self, response_message, session_id)
-            if result:
-                response_message = result
-        yield response_message
-
-    def forward(
-        self, *message: AgentMessage, session_id=0, **kwargs
-    ) -> Generator[Union[AgentMessage, Tuple[ModelStatusCode, str]], None, None]:
-        formatted_messages = self.aggregator.aggregate(
-            self.memory.get(session_id), self.name, self.output_format, self.template
-        )
-        for model_state, response, *_ in self.llm.stream_chat(formatted_messages, session_id=session_id, **kwargs):
-            yield (
-                AgentMessage(
-                    sender=self.name,
-                    content=response,
-                    formatted=self.output_format.parse_response(response),
-                    stream_state=model_state,
-                )
-                if self.output_format
-                else (model_state, response)
-            )
-
-
-class AsyncStreamingAgentMixin:
-    """Component that makes asynchronous agent calling output a streaming response."""
-
-    async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AsyncGenerator[AgentMessage, None]:
-        message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message]
-        for hook in self._hooks.values():
-            result = hook.before_agent(self, message, session_id)
-            if result:
-                message = result
-        self.update_memory(message, session_id=session_id)
-        response_message = AgentMessage(sender=self.name, content="")
-        async for response_message in self.forward(*message, session_id=session_id, **kwargs):
-            if not isinstance(response_message, AgentMessage):
-                model_state, response = response_message
-                response_message = AgentMessage(sender=self.name, content=response, stream_state=model_state)
-            yield response_message.model_copy()
-        self.update_memory(response_message, session_id=session_id)
-        response_message = copy.deepcopy(response_message)
-        for hook in self._hooks.values():
-            result = hook.after_agent(self, response_message, session_id)
-            if result:
-                response_message = result
-        yield response_message
-
-    async def forward(
-        self, *message: AgentMessage, session_id=0, **kwargs
-    ) -> AsyncGenerator[Union[AgentMessage, Tuple[ModelStatusCode, str]], None]:
-        formatted_messages = self.aggregator.aggregate(
-            self.memory.get(session_id), self.name, self.output_format, self.template
-        )
-        async for model_state, response, *_ in self.llm.stream_chat(
-            formatted_messages, session_id=session_id, **kwargs
-        ):
-            yield (
-                AgentMessage(
-                    sender=self.name,
-                    content=response,
-                    formatted=self.output_format.parse_response(response),
-                    stream_state=model_state,
-                )
-                if self.output_format
-                else (model_state, response)
             )
-
-
-class StreamingAgent(StreamingAgentMixin, Agent):
-    """Streaming variant of the Agent class"""
-
-    pass
-
-
-class AsyncStreamingAgent(AsyncStreamingAgentMixin, Agent):
-    """Streaming variant of the AsyncAgent class"""
-
-    pass


 class Sequential(Agent):
-    """Sequential is an agent container that forwards messages to each agent
     in the order they are added."""

-    def __init__(self, *agents: Union[Agent, Iterable], **kwargs):
         super().__init__(**kwargs)
         self._agents = OrderedDict()
         if not agents:
             raise ValueError('At least one agent should be provided')
-        if isinstance(agents[0],
             if not agents[0]:
                 raise ValueError('At least one agent should be provided')
             agents = agents[0]
@@ -333,11 +279,17 @@ class Sequential(Agent):
                 key, agent = agent
             self.add_agent(key, agent)

-    def add_agent(self, name: str, agent: Agent):
-        assert isinstance(
         self._agents[str(name)] = agent

-    def forward(self,
         assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
         if exit_at is None:
             exit_at = len(self) - 1
@@ -345,7 +297,7 @@ class Sequential(Agent):
         for _ in range(exit_at + 1):
             agent = next(iterator)
             if isinstance(message, AgentMessage):
-                message = (message,)
             message = agent(*message, session_id=session_id, **kwargs)
         return message

@@ -359,11 +311,13 @@ class Sequential(Agent):
|
|
359 |
return len(self._agents)
|
360 |
|
361 |
|
362 |
-
class AsyncSequential(
|
363 |
|
364 |
-
async def forward(
|
365 |
-
|
366 |
-
|
|
|
|
|
367 |
assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
|
368 |
if exit_at is None:
|
369 |
exit_at = len(self) - 1
|
@@ -371,43 +325,11 @@ class AsyncSequential(AsyncAgentMixin, Sequential):
|
|
371 |
for _ in range(exit_at + 1):
|
372 |
agent = next(iterator)
|
373 |
if isinstance(message, AgentMessage):
|
374 |
-
message = (message,)
|
375 |
message = await agent(*message, session_id=session_id, **kwargs)
|
376 |
return message
|
377 |
|
378 |
|
379 |
-
class StreamingSequential(StreamingAgentMixin, Sequential):
|
380 |
-
"""Streaming variant of the Sequential class"""
|
381 |
-
|
382 |
-
def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs):
|
383 |
-
assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
|
384 |
-
if exit_at is None:
|
385 |
-
exit_at = len(self) - 1
|
386 |
-
iterator = chain.from_iterable(repeat(self._agents.values()))
|
387 |
-
for _ in range(exit_at + 1):
|
388 |
-
agent = next(iterator)
|
389 |
-
if isinstance(message, AgentMessage):
|
390 |
-
message = (message,)
|
391 |
-
for message in agent(*message, session_id=session_id, **kwargs):
|
392 |
-
yield message
|
393 |
-
|
394 |
-
|
395 |
-
class AsyncStreamingSequential(AsyncStreamingAgentMixin, Sequential):
|
396 |
-
"""Streaming variant of the AsyncSequential class"""
|
397 |
-
|
398 |
-
async def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs):
|
399 |
-
assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
|
400 |
-
if exit_at is None:
|
401 |
-
exit_at = len(self) - 1
|
402 |
-
iterator = chain.from_iterable(repeat(self._agents.values()))
|
403 |
-
for _ in range(exit_at + 1):
|
404 |
-
agent = next(iterator)
|
405 |
-
if isinstance(message, AgentMessage):
|
406 |
-
message = (message,)
|
407 |
-
async for message in agent(*message, session_id=session_id, **kwargs):
|
408 |
-
yield message
|
409 |
-
|
410 |
-
|
411 |
class AgentContainerMixin:
|
412 |
|
413 |
def __init_subclass__(cls):
|
@@ -427,28 +349,33 @@ class AgentContainerMixin:
|
|
427 |
|
428 |
ret = func(self, *args, **kwargs)
|
429 |
agents = OrderedDict()
|
430 |
-
for k, item in self.data.items() if isinstance(
|
431 |
-
|
|
|
|
|
432 |
_backup(data)
|
433 |
-
raise KeyError(
|
|
|
434 |
if isinstance(k, str) and '.' in k:
|
435 |
_backup(data)
|
436 |
-
raise KeyError(
|
437 |
-
|
|
|
438 |
_backup(data)
|
439 |
-
raise TypeError(
|
|
|
|
|
440 |
agents[str(k)] = item
|
441 |
self._agents = agents
|
442 |
return ret
|
443 |
|
444 |
return wrapped_func
|
445 |
|
446 |
-
# fmt: off
|
447 |
for method in [
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
]:
|
453 |
if hasattr(cls, method):
|
454 |
setattr(cls, method, wrap_api(getattr(cls, method)))
|
@@ -456,7 +383,8 @@ class AgentContainerMixin:
|
|
456 |
|
457 |
class AgentList(Agent, UserList, AgentContainerMixin):
|
458 |
|
459 |
-
def __init__(self,
|
|
|
460 |
Agent.__init__(self, memory=None)
|
461 |
UserList.__init__(self, agents)
|
462 |
self.name = None
|
@@ -464,7 +392,9 @@ class AgentList(Agent, UserList, AgentContainerMixin):
|
|
464 |
|
465 |
class AgentDict(Agent, UserDict, AgentContainerMixin):
|
466 |
|
467 |
-
def __init__(self,
|
|
|
|
|
468 |
Agent.__init__(self, memory=None)
|
469 |
UserDict.__init__(self, agents)
|
470 |
self.name = None
|
from collections import OrderedDict, UserDict, UserList, abc
from functools import wraps
from itertools import chain, repeat
+ from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union

from lagent.agents.aggregator import DefaultAggregator
from lagent.hooks import Hook, RemovableHandle
...
from lagent.memory import Memory, MemoryManager
from lagent.prompts.parsers import StrParser
from lagent.prompts.prompt_template import PromptTemplate
+ from lagent.schema import AgentMessage
from lagent.utils import create_object

...
        if self.memory:
            self.memory.add(message, session_id=session_id)

+     def __call__(
+         self,
+         *message: Union[str, AgentMessage, List[AgentMessage]],
+         session_id=0,
+         **kwargs,
+     ) -> AgentMessage:
        # message.receiver = self.name
+         message = [
+             AgentMessage(sender='user', content=m)
+             if isinstance(m, str) else copy.deepcopy(m) for m in message
+         ]
        for hook in self._hooks.values():
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
+         response_message = self.forward(
+             *message, session_id=session_id, **kwargs)
        if not isinstance(response_message, AgentMessage):
+             response_message = AgentMessage(
+                 sender=self.name,
+                 content=response_message,
+             )
        self.update_memory(response_message, session_id=session_id)
        response_message = copy.deepcopy(response_message)
        for hook in self._hooks.values():
...
                response_message = result
        return response_message

+     def forward(self,
+                 *message: AgentMessage,
+                 session_id=0,
+                 **kwargs) -> Union[AgentMessage, str]:
        formatted_messages = self.aggregator.aggregate(
+             self.memory.get(session_id),
+             self.name,
+             self.output_format,
+             self.template,
        )
        llm_response = self.llm.chat(formatted_messages, **kwargs)
        if self.output_format:
+             formatted_messages = self.output_format.parse_response(
+                 llm_response)
+             return AgentMessage(
+                 sender=self.name,
+                 content=llm_response,
+                 formatted=formatted_messages,
+             )
        return llm_response

    def __setattr__(self, __name: str, __value: Any) -> None:
...
        self._hooks[handle.id] = hook
        return handle

+     def reset(self,
+               session_id=0,
+               keypath: Optional[str] = None,
+               recursive: bool = False):
+         assert not (keypath and
+                     recursive), 'keypath and recursive can\'t be used together'
        if keypath:
            keys, agent = keypath.split('.'), self
            for key in keys:
...
    def __repr__(self):

        def _rcsv_repr(agent, n_indent=1):
+             res = agent.__class__.__name__ + (f"(name='{agent.name}')"
+                                               if agent.name else '')
            modules = [
                f"{n_indent * ' '}({name}): {_rcsv_repr(agent, n_indent + 1)}"
                for name, agent in getattr(agent, '_agents', {}).items()
            ]
            if modules:
+                 res += '(\n' + '\n'.join(
+                     modules) + f'\n{(n_indent - 1) * " "})'
            elif not res.endswith(')'):
                res += '()'
            return res
...
        return _rcsv_repr(self)


+ class AsyncAgent(Agent):

+     async def __call__(self,
+                        *message: AgentMessage | List[AgentMessage],
+                        session_id=0,
+                        **kwargs) -> AgentMessage:
+         message = [
+             AgentMessage(sender='user', content=m)
+             if isinstance(m, str) else copy.deepcopy(m) for m in message
+         ]
        for hook in self._hooks.values():
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
+         response_message = await self.forward(
+             *message, session_id=session_id, **kwargs)
        if not isinstance(response_message, AgentMessage):
+             response_message = AgentMessage(
+                 sender=self.name,
+                 content=response_message,
+             )
        self.update_memory(response_message, session_id=session_id)
        response_message = copy.deepcopy(response_message)
        for hook in self._hooks.values():
...
                response_message = result
        return response_message

+     async def forward(self,
+                       *message: AgentMessage,
+                       session_id=0,
+                       **kwargs) -> Union[AgentMessage, str]:
        formatted_messages = self.aggregator.aggregate(
+             self.memory.get(session_id),
+             self.name,
+             self.output_format,
+             self.template,
        )
+         llm_response = await self.llm.chat(formatted_messages, session_id,
+                                            **kwargs)
        if self.output_format:
+             formatted_messages = self.output_format.parse_response(
+                 llm_response)
+             return AgentMessage(
+                 sender=self.name,
+                 content=llm_response,
+                 formatted=formatted_messages,
            )
+         return llm_response


class Sequential(Agent):
+     """Sequential is an agent container that forwards messages to each agent
    in the order they are added."""

+     def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs):
        super().__init__(**kwargs)
        self._agents = OrderedDict()
        if not agents:
            raise ValueError('At least one agent should be provided')
+         if isinstance(agents[0],
+                       Iterable) and not isinstance(agents[0], Agent):
            if not agents[0]:
                raise ValueError('At least one agent should be provided')
            agents = agents[0]
...
                key, agent = agent
            self.add_agent(key, agent)

+     def add_agent(self, name: str, agent: Union[Agent, AsyncAgent]):
+         assert isinstance(
+             agent, (Agent, AsyncAgent
+                     )), f'{type(agent)} is not an Agent or AsyncAgent subclass'
        self._agents[str(name)] = agent

+     def forward(self,
+                 *message: AgentMessage,
+                 session_id=0,
+                 exit_at: Optional[int] = None,
+                 **kwargs) -> AgentMessage:
        assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
        if exit_at is None:
            exit_at = len(self) - 1
...
        for _ in range(exit_at + 1):
            agent = next(iterator)
            if isinstance(message, AgentMessage):
+                 message = (message, )
            message = agent(*message, session_id=session_id, **kwargs)
        return message
...
        return len(self._agents)


+ class AsyncSequential(Sequential, AsyncAgent):

+     async def forward(self,
+                       *message: AgentMessage,
+                       session_id=0,
+                       exit_at: Optional[int] = None,
+                       **kwargs) -> AgentMessage:
        assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
        if exit_at is None:
            exit_at = len(self) - 1
...
        for _ in range(exit_at + 1):
            agent = next(iterator)
            if isinstance(message, AgentMessage):
+                 message = (message, )
            message = await agent(*message, session_id=session_id, **kwargs)
        return message


class AgentContainerMixin:

    def __init_subclass__(cls):
...
            ret = func(self, *args, **kwargs)
            agents = OrderedDict()
+             for k, item in (self.data.items() if isinstance(
+                     self.data, abc.Mapping) else enumerate(self.data)):
+                 if isinstance(self.data,
+                               abc.Mapping) and not isinstance(k, str):
                    _backup(data)
+                     raise KeyError(
+                         f'agent name should be a string, got {type(k)}')
                if isinstance(k, str) and '.' in k:
                    _backup(data)
+                     raise KeyError(
+                         f'agent name can\'t contain ".", got {k}')
+                 if not isinstance(item, (Agent, AsyncAgent)):
                    _backup(data)
+                     raise TypeError(
+                         f'{type(item)} is not an Agent or AsyncAgent subclass'
+                     )
                agents[str(k)] = item
            self._agents = agents
            return ret

        return wrapped_func
...
        for method in [
+                 'append', 'sort', 'reverse', 'pop', 'clear', 'update',
+                 'insert', 'extend', 'remove', '__init__', '__setitem__',
+                 '__delitem__', '__add__', '__iadd__', '__radd__', '__mul__',
+                 '__imul__', '__rmul__'
        ]:
            if hasattr(cls, method):
                setattr(cls, method, wrap_api(getattr(cls, method)))
...

class AgentList(Agent, UserList, AgentContainerMixin):

+     def __init__(self,
+                  agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None):
        Agent.__init__(self, memory=None)
        UserList.__init__(self, agents)
        self.name = None
...

class AgentDict(Agent, UserDict, AgentContainerMixin):

+     def __init__(self,
+                  agents: Optional[Mapping[str, Union[Agent,
+                                                      AsyncAgent]]] = None):
        Agent.__init__(self, memory=None)
        UserDict.__init__(self, agents)
        self.name = None

lagent/agents/aggregator/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (305 Bytes)

lagent/agents/aggregator/__pycache__/default_aggregator.cpython-310.pyc
ADDED
Binary file (1.6 kB)

lagent/agents/aggregator/__pycache__/tool_aggregator.cpython-310.pyc
ADDED
Binary file (2.71 kB)
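For reference, a minimal usage sketch of the reworked Agent and Sequential API defined in lagent/agents/agent.py above. This is not part of the commit: the GPT model name and API key are placeholders and the prompt templates are illustrative.

    from lagent.agents.agent import Agent, Sequential
    from lagent.llms import GPTAPI
    from lagent.schema import AgentMessage

    # Placeholder LLM config; Agent accepts either a BaseLLM instance or a config dict.
    llm = dict(type=GPTAPI, model_type='gpt-4o-2024-05-13', key='YOUR_OPENAI_KEY')

    # Each Agent aggregates its memory, calls llm.chat and, if an output_format
    # is configured, parses the reply into AgentMessage.formatted.
    writer = Agent(llm=llm, template='Draft a short answer to the user question.')
    critic = Agent(llm=llm, template='Review the draft and point out mistakes.')

    # Sequential forwards the message to each agent in the order they are added.
    pipeline = Sequential(writer, critic)
    reply = pipeline(AgentMessage(sender='user', content='What does a hook do?'), session_id=0)
    print(reply.content)

Because AgentList and AgentDict wrap UserList and UserDict, the same agents could also be kept in mutable containers, with the wrapped mutation methods re-validating entries on every change.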
lagent/agents/react.py
CHANGED
@@ -12,6 +12,7 @@ from lagent.memory import Memory
from lagent.prompts.parsers.json_parser import JSONParser
from lagent.prompts.prompt_template import PromptTemplate
from lagent.schema import AgentMessage

select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括:
{action_info}
@@ -27,88 +28,96 @@ output_format_template = """如果使用工具请遵循以下格式回复:

class ReAct(Agent):

-     def __init__(
-         **kwargs
-     ):
        self.max_turn = max_turn
        self.finish_condition = finish_condition
            llm=llm,
            template=template.format(
-                 action_info=json.dumps(self.actions.description()),
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=hooks,
        )
        super().__init__(**kwargs)

-     def forward(self, message: AgentMessage,
        for _ in range(self.max_turn):
-             message = self.select_agent(message
            if self.finish_condition(message):
                return message
-             message = self.actions(message
        return message


class AsyncReAct(AsyncAgent):

-     def __init__(
-         **kwargs
-     ):
        self.max_turn = max_turn
        self.finish_condition = finish_condition
            llm=llm,
            template=template.format(
-                 action_info=json.dumps(self.actions.description()),
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=hooks,
        )
        super().__init__(**kwargs)

-     async def forward(self, message: AgentMessage,
        for _ in range(self.max_turn):
-             message = await self.select_agent(message
            if self.finish_condition(message):
                return message
-             message = await self.actions(message
        return message


if __name__ == '__main__':
-     import
-     from lagent.llms import GPTAPI, AsyncGPTAPI

    class ActionCall(BaseModel):
        name: str = Field(description='调用的函数名称')
@@ -116,49 +125,37 @@ if __name__ == '__main__':

    class ActionFormat(BaseModel):
        thought_process: str = Field(
-             description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。'
-         )
        action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。')

    class FinishFormat(BaseModel):
        thought_process: str = Field(
-             description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。'
-         )
        conclusion: str = Field(description='总结当前的搜索结果,回答问题。')

    prompt_template = PromptTemplate(select_action_template)
-     output_format = JSONParser(

    agent = ReAct(
-         llm=
-             type=GPTAPI,
-             model_type='gpt-4o-2024-05-13',
-             max_new_tokens=4096,
-             proxies=dict(),
-             retry=1000,
-         ),
        template=prompt_template,
        output_format=output_format,
-         aggregator=dict(type='
-         actions=[dict(type='
    )
-     response = agent(
    print(response)
    response = agent(AgentMessage(sender='user', content=' 2 ** 5 呢'))
    print(response)
-
-     async_agent = AsyncReAct(
-         llm=dict(
-             type=AsyncGPTAPI,
-             model_type='gpt-4o-2024-05-13',
-             max_new_tokens=4096,
-             proxies=dict(),
-             retry=1000,
-         ),
-         template=prompt_template,
-         output_format=output_format,
-         aggregator=dict(type='lagent.agents.aggregator.DefaultAggregator'),
-         actions=[dict(type='lagent.actions.AsyncPythonInterpreter')],
-     )
-     response = asyncio.run(async_agent(AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5')))
-     print(async_agent.state_dict())
...
from lagent.prompts.parsers.json_parser import JSONParser
from lagent.prompts.prompt_template import PromptTemplate
from lagent.schema import AgentMessage
+ from lagent.utils import create_object

select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括:
{action_info}
...

class ReAct(Agent):

+     def __init__(self,
+                  llm: Union[BaseLLM, Dict],
+                  actions: Union[BaseAction, List[BaseAction]],
+                  template: Union[PromptTemplate, str] = None,
+                  memory: Dict = dict(type=Memory),
+                  output_format: Dict = dict(type=JSONParser),
+                  aggregator: Dict = dict(type=DefaultAggregator),
+                  hooks: List = [dict(type=ActionPreprocessor)],
+                  finish_condition: Callable[[AgentMessage], bool] = lambda m:
+                  'conclusion' in m.content or 'conclusion' in m.formatted,
+                  max_turn: int = 5,
+                  **kwargs):
        self.max_turn = max_turn
        self.finish_condition = finish_condition
+         actions = dict(
+             type=ActionExecutor,
+             actions=actions,
+             hooks=hooks,
+         )
+         self.actions: ActionExecutor = create_object(actions)
+         select_agent = dict(
+             type=Agent,
            llm=llm,
            template=template.format(
+                 action_info=json.dumps(self.actions.description()),
+                 output_format=output_format.format_instruction()),
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=hooks,
        )
+         self.select_agent = create_object(select_agent)
        super().__init__(**kwargs)

+     def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
        for _ in range(self.max_turn):
+             message = self.select_agent(message)
            if self.finish_condition(message):
                return message
+             message = self.actions(message)
        return message


class AsyncReAct(AsyncAgent):

+     def __init__(self,
+                  llm: Union[BaseLLM, Dict],
+                  actions: Union[BaseAction, List[BaseAction]],
+                  template: Union[PromptTemplate, str] = None,
+                  memory: Dict = dict(type=Memory),
+                  output_format: Dict = dict(type=JSONParser),
+                  aggregator: Dict = dict(type=DefaultAggregator),
+                  hooks: List = [dict(type=ActionPreprocessor)],
+                  finish_condition: Callable[[AgentMessage], bool] = lambda m:
+                  'conclusion' in m.content or 'conclusion' in m.formatted,
+                  max_turn: int = 5,
+                  **kwargs):
        self.max_turn = max_turn
        self.finish_condition = finish_condition
+         actions = dict(
+             type=AsyncActionExecutor,
+             actions=actions,
+             hooks=hooks,
+         )
+         self.actions: AsyncActionExecutor = create_object(actions)
+         select_agent = dict(
+             type=AsyncAgent,
            llm=llm,
            template=template.format(
+                 action_info=json.dumps(self.actions.description()),
+                 output_format=output_format.format_instruction()),
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=hooks,
        )
+         self.select_agent = create_object(select_agent)
        super().__init__(**kwargs)

+     async def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
        for _ in range(self.max_turn):
+             message = await self.select_agent(message)
            if self.finish_condition(message):
                return message
+             message = await self.actions(message)
        return message


if __name__ == '__main__':
+     from lagent.llms import GPTAPI

    class ActionCall(BaseModel):
        name: str = Field(description='调用的函数名称')
...
    class ActionFormat(BaseModel):
        thought_process: str = Field(
+             description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。')
        action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。')

    class FinishFormat(BaseModel):
        thought_process: str = Field(
+             description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。')
        conclusion: str = Field(description='总结当前的搜索结果,回答问题。')

    prompt_template = PromptTemplate(select_action_template)
+     output_format = JSONParser(
+         output_format_template,
+         function_format=ActionFormat,
+         finish_format=FinishFormat)
+
+     llm = dict(
+         type=GPTAPI,
+         model_type='gpt-4o-2024-05-13',
+         key=None,
+         max_new_tokens=4096,
+         proxies=dict(),
+         retry=1000)

    agent = ReAct(
+         llm=llm,
        template=prompt_template,
        output_format=output_format,
+         aggregator=dict(type='DefaultAggregator'),
+         actions=[dict(type='PythonInterpreter')],
    )
+     response = agent(
+         AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5'))
    print(response)
    response = agent(AgentMessage(sender='user', content=' 2 ** 5 呢'))
    print(response)
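The ReAct loop above alternates between select_agent (the LLM picks an action) and the ActionExecutor until finish_condition fires or max_turn runs out. As a sketch only, the stop criterion can be swapped for a custom callable; this reuses the llm, prompt_template and output_format objects from the __main__ example above and is not part of the commit.

    from lagent.schema import AgentMessage

    def finished(message: AgentMessage) -> bool:
        # Stricter variant of the default check: only inspect the parsed payload,
        # not the raw content string.
        formatted = message.formatted if isinstance(message.formatted, dict) else {}
        return 'conclusion' in formatted

    agent = ReAct(
        llm=llm,                      # GPTAPI config from the example above
        template=prompt_template,     # PromptTemplate(select_action_template)
        output_format=output_format,  # JSONParser with ActionFormat / FinishFormat
        actions=[dict(type='PythonInterpreter')],
        finish_condition=finished,
        max_turn=3,
    )
    print(agent(AgentMessage(sender='user', content='Compute 12 factorial with Python')))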
lagent/agents/stream.py
CHANGED
@@ -15,30 +15,25 @@ from lagent.utils import create_object

API_PREFIX = (
    "This is the subfunction for tool '{tool_name}', you can use this tool. "
-     'The description of this function is: \n{description}'
- )

- META_CN = '当开启工具以及代码时,根据需求选择合适的工具进行调用'

- INTERPRETER_CN = (
-     '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。'
- )

- PLUGIN_CN = (
-     '同时注意你可以使用的工具,不要随意捏造!'
- )


- def get_plugin_prompt(actions, api_desc_template=
    plugin_descriptions = []
    for action in actions if isinstance(actions, list) else [actions]:
        action = create_object(action)
@@ -46,9 +41,20 @@ def get_plugin_prompt(actions, api_desc_template='{description}'):
        if action.is_toolkit:
            for api in action_desc['api_list']:
                api['name'] = f"{action.name}.{api['name']}"
-                 api['description'] = api_desc_template.format(
            plugin_descriptions.append(api)
        else:
            plugin_descriptions.append(action_desc)
    return json.dumps(plugin_descriptions, ensure_ascii=False, indent=4)
...
            parsers=[
                dict(type=PluginParser, template=PLUGIN_CN),
                dict(type=InterpreterParser, template=INTERPRETER_CN),
-             ],
-         ),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
-         finish_condition: Callable[
        max_turn: int = 4,
        **kwargs,
    ):
-
            llm=llm,
            template=template,
            output_format=output_format,
@@ -86,18 +94,22 @@ class AgentForInternLM(Agent):
            aggregator=aggregator,
            hooks=kwargs.pop('hooks', None),
        )
-         self.
-         self.
        if not (self.plugin_executor or self.interpreter_executor):
            warnings.warn(
                'Neither plugin nor interpreter executor is initialized. '
-                 'An exception will be thrown when the agent call a tool.'
-             )
        self.finish_condition = finish_condition
        self.max_turn = max_turn
        super().__init__(**kwargs)

    def forward(self, message: AgentMessage, session_id=0, **kwargs):
        for _ in range(self.max_turn):
            message = self.agent(message, session_id=session_id, **kwargs)
            assert isinstance(message.formatted, dict)
@@ -115,10 +127,15 @@ class AgentForInternLM(Agent):
        steps, tool_type = [], None
        for msg in self.agent.memory.get_memory(session_id):
            if msg.sender == self.agent.name:
-                 steps.append(
                if msg.formatted['tool_type']:
                    tool_type = msg.formatted['tool_type']
-                     steps.append(
            elif msg.sender != 'user':
                feedback = dict(role='environment', content=msg.content)
                if tool_type:
@@ -132,22 +149,23 @@ class MathCoder(AgentForInternLM):
    def __init__(
        self,
        llm: Union[BaseLLM, Dict],
-         interpreter: dict = dict(
        template: Union[str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(
            type=InterpreterParser,
-             template=
-         ),
-     ),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
-         finish_condition: Callable[
        max_turn: int = 6,
        **kwargs,
    ):
@@ -162,8 +180,7 @@ class MathCoder(AgentForInternLM):
        action_hooks=action_hooks,
        finish_condition=finish_condition,
        max_turn=max_turn,
-         **kwargs
-     )


class AsyncAgentForInternLM(AsyncAgent):
@@ -183,15 +200,17 @@ class AsyncAgentForInternLM(AsyncAgent):
            parsers=[
                dict(type=PluginParser, template=PLUGIN_CN),
                dict(type=InterpreterParser, template=INTERPRETER_CN),
-             ],
-         ),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
-         finish_condition: Callable[
        max_turn: int = 4,
        **kwargs,
    ):
-
            llm=llm,
            template=template,
            output_format=output_format,
@@ -199,20 +218,25 @@ class AsyncAgentForInternLM(AsyncAgent):
            aggregator=aggregator,
            hooks=kwargs.pop('hooks', None),
        )
-         self.
-         self.
        if not (self.plugin_executor or self.interpreter_executor):
            warnings.warn(
                'Neither plugin nor interpreter executor is initialized. '
-                 'An exception will be thrown when the agent call a tool.'
-             )
        self.finish_condition = finish_condition
        self.max_turn = max_turn
        super().__init__(**kwargs)

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
        for _ in range(self.max_turn):
-             message = await self.agent(
            assert isinstance(message.formatted, dict)
            if self.finish_condition(message):
                return message
@@ -228,10 +252,15 @@ class AsyncAgentForInternLM(AsyncAgent):
        steps, tool_type = [], None
        for msg in self.agent.memory.get_memory(session_id):
            if msg.sender == self.agent.name:
-                 steps.append(
                if msg.formatted['tool_type']:
                    tool_type = msg.formatted['tool_type']
-                     steps.append(
            elif msg.sender != 'user':
                feedback = dict(role='environment', content=msg.content)
                if tool_type:
@@ -250,17 +279,17 @@ class AsyncMathCoder(AsyncAgentForInternLM):
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(
            type=InterpreterParser,
-             template=
-         ),
-     ),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
-         finish_condition: Callable[
        max_turn: int = 6,
        **kwargs,
    ):
@@ -275,13 +304,13 @@ class AsyncMathCoder(AsyncAgentForInternLM):
        action_hooks=action_hooks,
        finish_condition=finish_condition,
        max_turn=max_turn,
-         **kwargs
-     )

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
        try:
            return await super().forward(message, session_id, **kwargs)
        finally:
-             interpreter = next(
            if interpreter.name == 'AsyncIPythonInterpreter':
                await interpreter.close_session(session_id)
API_PREFIX = (
    "This is the subfunction for tool '{tool_name}', you can use this tool. "
+     'The description of this function is: \n{description}')

+ META_CN = ('当开启工具以及代码时,根据需求选择合适的工具进行调用')

+ INTERPRETER_CN = ('你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。'
+                   '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。'
+                   '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),'
+                   '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),'
+                   '文本处理和分析(比如文本解析和自然语言处理),'
+                   '机器学习和数据科学(用于展示模型训练和数据可视化),'
+                   '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。')

+ PLUGIN_CN = ('你可以使用如下工具:'
+              '\n{prompt}\n'
+              '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! '
+              '同时注意你可以使用的工具,不要随意捏造!')


+ def get_plugin_prompt(actions, api_desc_template=API_PREFIX):
    plugin_descriptions = []
    for action in actions if isinstance(actions, list) else [actions]:
        action = create_object(action)
...
        if action.is_toolkit:
            for api in action_desc['api_list']:
                api['name'] = f"{action.name}.{api['name']}"
+                 api['description'] = api_desc_template.format(
+                     tool_name=action.name, description=api['description'])
+                 api['parameters'] = [
+                     param for param in api['parameters']
+                     if param['name'] in api['required']
+                 ]
            plugin_descriptions.append(api)
        else:
+             action_desc['description'] = api_desc_template.format(
+                 tool_name=action.name, description=action_desc['description'])
+             action_desc['parameters'] = [
+                 param for param in action_desc['parameters']
+                 if param['name'] in action_desc['required']
+             ]
            plugin_descriptions.append(action_desc)
    return json.dumps(plugin_descriptions, ensure_ascii=False, indent=4)
...
            parsers=[
                dict(type=PluginParser, template=PLUGIN_CN),
                dict(type=InterpreterParser, template=INTERPRETER_CN),
+             ]),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
+         finish_condition: Callable[
+             [AgentMessage],
+             bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
        max_turn: int = 4,
        **kwargs,
    ):
+         agent = dict(
+             type=self._INTERNAL_AGENT_CLS,
            llm=llm,
            template=template,
            output_format=output_format,
...
            aggregator=aggregator,
            hooks=kwargs.pop('hooks', None),
        )
+         self.agent = create_object(agent)
+         self.plugin_executor = plugins and ActionExecutor(
+             plugins, hooks=action_hooks)
+         self.interpreter_executor = interpreter and ActionExecutor(
+             interpreter, hooks=action_hooks)
        if not (self.plugin_executor or self.interpreter_executor):
            warnings.warn(
                'Neither plugin nor interpreter executor is initialized. '
+                 'An exception will be thrown when the agent call a tool.')
        self.finish_condition = finish_condition
        self.max_turn = max_turn
        super().__init__(**kwargs)

    def forward(self, message: AgentMessage, session_id=0, **kwargs):
+         if isinstance(message, str):
+             message = AgentMessage(sender='user', content=message)
        for _ in range(self.max_turn):
            message = self.agent(message, session_id=session_id, **kwargs)
            assert isinstance(message.formatted, dict)
...
        steps, tool_type = [], None
        for msg in self.agent.memory.get_memory(session_id):
            if msg.sender == self.agent.name:
+                 steps.append(
+                     dict(role='thought', content=msg.formatted['thought']))
                if msg.formatted['tool_type']:
                    tool_type = msg.formatted['tool_type']
+                     steps.append(
+                         dict(
+                             role='tool',
+                             content=msg.formatted['action'],
+                             name=tool_type))
            elif msg.sender != 'user':
                feedback = dict(role='environment', content=msg.content)
                if tool_type:
...
    def __init__(
        self,
        llm: Union[BaseLLM, Dict],
+         interpreter: dict = dict(
+             type=IPythonInteractive, timeout=20, max_out_len=8192),
        template: Union[str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(
            type=InterpreterParser,
+             template=
+             ('Integrate step-by-step reasoning and Python code to solve math problems '
+              'using the following guidelines:\n'
+              '- Analyze the question and write jupyter code to solve the problem;\n'
+              r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
+              'units. \n')),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
+         finish_condition: Callable[
+             [AgentMessage],
+             bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
        max_turn: int = 6,
        **kwargs,
    ):
...
        action_hooks=action_hooks,
        finish_condition=finish_condition,
        max_turn=max_turn,
+         **kwargs)


class AsyncAgentForInternLM(AsyncAgent):
...
            parsers=[
                dict(type=PluginParser, template=PLUGIN_CN),
                dict(type=InterpreterParser, template=INTERPRETER_CN),
+             ]),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
+         finish_condition: Callable[
+             [AgentMessage],
+             bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
        max_turn: int = 4,
        **kwargs,
    ):
+         agent = dict(
+             type=self._INTERNAL_AGENT_CLS,
            llm=llm,
            template=template,
            output_format=output_format,
...
            aggregator=aggregator,
            hooks=kwargs.pop('hooks', None),
        )
+         self.agent = create_object(agent)
+         self.plugin_executor = plugins and AsyncActionExecutor(
+             plugins, hooks=action_hooks)
+         self.interpreter_executor = interpreter and AsyncActionExecutor(
+             interpreter, hooks=action_hooks)
        if not (self.plugin_executor or self.interpreter_executor):
            warnings.warn(
                'Neither plugin nor interpreter executor is initialized. '
+                 'An exception will be thrown when the agent call a tool.')
        self.finish_condition = finish_condition
        self.max_turn = max_turn
        super().__init__(**kwargs)

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
+         if isinstance(message, str):
+             message = AgentMessage(sender='user', content=message)
        for _ in range(self.max_turn):
+             message = await self.agent(
+                 message, session_id=session_id, **kwargs)
            assert isinstance(message.formatted, dict)
            if self.finish_condition(message):
                return message
...
        steps, tool_type = [], None
        for msg in self.agent.memory.get_memory(session_id):
            if msg.sender == self.agent.name:
+                 steps.append(
+                     dict(role='thought', content=msg.formatted['thought']))
                if msg.formatted['tool_type']:
                    tool_type = msg.formatted['tool_type']
+                     steps.append(
+                         dict(
+                             role='tool',
+                             content=msg.formatted['action'],
+                             name=tool_type))
            elif msg.sender != 'user':
                feedback = dict(role='environment', content=msg.content)
                if tool_type:
...
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(
            type=InterpreterParser,
+             template=
+             ('Integrate step-by-step reasoning and Python code to solve math problems '
+              'using the following guidelines:\n'
+              '- Analyze the question and write jupyter code to solve the problem;\n'
+              r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
+              'units. \n')),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
+         finish_condition: Callable[
+             [AgentMessage],
+             bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
        max_turn: int = 6,
        **kwargs,
    ):
...
        action_hooks=action_hooks,
        finish_condition=finish_condition,
        max_turn=max_turn,
+         **kwargs)

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
        try:
            return await super().forward(message, session_id, **kwargs)
        finally:
+             interpreter = next(
+                 iter(self.interpreter_executor.actions.values()))
            if interpreter.name == 'AsyncIPythonInterpreter':
                await interpreter.close_session(session_id)
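The updated get_plugin_prompt now injects the API_PREFIX description and keeps only each tool's required parameters. A short sketch of splicing its output into PLUGIN_CN (not part of the commit; it assumes ArxivSearch is exported from lagent.actions, which ships several such actions):

    from lagent.actions import ArxivSearch
    from lagent.agents.stream import PLUGIN_CN, get_plugin_prompt

    # Build the JSON tool list and insert it into the plugin system prompt.
    plugin_prompt = PLUGIN_CN.format(prompt=get_plugin_prompt([dict(type=ArxivSearch)]))
    print(plugin_prompt)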
lagent/distributed/http_serve/api_server.py
CHANGED
@@ -3,7 +3,6 @@ import os
import subprocess
import sys
import time
- import threading

import aiohttp
import requests
@@ -78,21 +77,14 @@ class HTTPAgentServer(HTTPAgentClient):
            stderr=subprocess.STDOUT,
            text=True)

-         # Start log output thread
-         threading.Thread(target=log_output, args=(self.process.stdout,), daemon=True).start()
-         threading.Thread(target=log_output, args=(self.process.stderr,), daemon=True).start()
-
-         # Waiting for the service to start
-         while not self.service_started:
            time.sleep(0.1)

    def shutdown(self):
...
import subprocess
import sys
import time

import aiohttp
import requests
...
            stderr=subprocess.STDOUT,
            text=True)

+         while True:
+             output = self.process.stdout.readline()
+             if not output:  # 如果读到 EOF,跳出循环
+                 break
+             sys.stdout.write(output)  # 打印到标准输出
+             sys.stdout.flush()
+             if 'Uvicorn running on' in output:  # 根据实际输出调整
+                 break
            time.sleep(0.1)

    def shutdown(self):
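The server startup above drops the log-forwarding threads and instead blocks on the child's stdout until the "Uvicorn running on" banner appears. A generic sketch of the same readiness-wait pattern (illustrative only; the command and module name are hypothetical placeholders):

    import subprocess
    import sys

    def wait_until_ready(cmd, banner='Uvicorn running on'):
        # Echo the child's output until the readiness banner (or EOF) shows up.
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT, text=True)
        for line in proc.stdout:
            sys.stdout.write(line)
            sys.stdout.flush()
            if banner in line:
                break
        return proc

    server = wait_until_ready([sys.executable, '-m', 'my_api_module'])  # hypothetical module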
lagent/hooks/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (387 Bytes)

lagent/hooks/__pycache__/action_preprocessor.cpython-310.pyc
ADDED
Binary file (2.29 kB)

lagent/hooks/__pycache__/hook.cpython-310.pyc
ADDED
Binary file (1.53 kB)

lagent/hooks/__pycache__/logger.cpython-310.pyc
ADDED
Binary file (1.77 kB)

lagent/hooks/logger.py
CHANGED
@@ -1,4 +1,5 @@
import random

from termcolor import COLORS, colored

@@ -7,10 +8,10 @@ from .hook import Hook


class MessageLogger(Hook):
        self.logger = get_logger(
-             name, 'info', '%(asctime)s %(levelname)8s %(name)8s - %(message)s'
-         )
        self.sender2color = {}

    def before_agent(self, agent, messages, session_id):
@@ -28,5 +29,9 @@ class MessageLogger(Hook):

    def _process_message(self, message, session_id):
        sender = message.sender
-         color = self.sender2color.setdefault(sender,
...
import random
+ from typing import Optional

from termcolor import COLORS, colored
...


class MessageLogger(Hook):
+
+     def __init__(self, name: str = 'lagent'):
        self.logger = get_logger(
+             name, 'info', '%(asctime)s %(levelname)8s %(name)8s - %(message)s')
        self.sender2color = {}

    def before_agent(self, agent, messages, session_id):
...

    def _process_message(self, message, session_id):
        sender = message.sender
+         color = self.sender2color.setdefault(sender,
+                                              random.choice(list(COLORS)))
+         self.logger.info(
+             colored(
+                 f'session id: {session_id}, message sender: {sender}\n'
+                 f'{message.content}', color))
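MessageLogger now takes a logger name and prints each message with a per-sender color. A minimal sketch of wiring it into an agent via the hooks config (not part of the commit; it assumes MessageLogger is re-exported from lagent.hooks, and the LLM config is a placeholder):

    from lagent.agents.agent import Agent
    from lagent.hooks import MessageLogger
    from lagent.llms import GPTAPI
    from lagent.schema import AgentMessage

    # Every message passing through before_agent/after_agent is logged in color.
    agent = Agent(
        llm=dict(type=GPTAPI, model_type='gpt-4o-2024-05-13', key='YOUR_OPENAI_KEY'),
        template='You are a helpful assistant.',
        hooks=[dict(type=MessageLogger, name='demo')],
    )
    agent(AgentMessage(sender='user', content='hi'))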
lagent/llms/__init__.py
CHANGED
@@ -1,15 +1,9 @@
- from .anthropic_llm import AsyncClaudeAPI, ClaudeAPI
from .base_api import AsyncBaseAPILLM, BaseAPILLM
from .base_llm import AsyncBaseLLM, BaseLLM
from .huggingface import HFTransformer, HFTransformerCasualLM, HFTransformerChat
- from .lmdeploy_wrapper import (
-     AsyncLMDeployServer,
-     LMDeployClient,
-     LMDeployPipeline,
-     LMDeployServer,
- )
from .meta_template import INTERNLM2_META
from .openai import GPTAPI, AsyncGPTAPI
from .sensenova import SensenovaAPI
@@ -35,6 +29,4 @@ __all__ = [
    'VllmModel',
    'AsyncVllmModel',
    'SensenovaAPI',
-     'AsyncClaudeAPI',
-     'ClaudeAPI',
]
...
from .base_api import AsyncBaseAPILLM, BaseAPILLM
from .base_llm import AsyncBaseLLM, BaseLLM
from .huggingface import HFTransformer, HFTransformerCasualLM, HFTransformerChat
+ from .lmdeploy_wrapper import (AsyncLMDeployClient, AsyncLMDeployPipeline,
+                                AsyncLMDeployServer, LMDeployClient,
+                                LMDeployPipeline, LMDeployServer)
from .meta_template import INTERNLM2_META
from .openai import GPTAPI, AsyncGPTAPI
from .sensenova import SensenovaAPI
...
    'VllmModel',
    'AsyncVllmModel',
    'SensenovaAPI',
]