diff --git a/.github/dependabot.yml b/.github/dependabot.yml deleted file mode 100644 index fdd19973f50ba93629622e5de07bc88391110cbc..0000000000000000000000000000000000000000 --- a/.github/dependabot.yml +++ /dev/null @@ -1,12 +0,0 @@ -# To get started with Dependabot version updates, you'll need to specify which -# package ecosystems to update and where the package manifests are located. -# Please see the documentation for all configuration options: -# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file - -version: 2 -updates: - - package-ecosystem: "" # See documentation for possible values - directory: "/" # Location of package manifests - schedule: - interval: "weekly" - diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1f51fcc609c036d518712710b11fcefa1c38100b..5066bc4981e629097ecfe35bfab46540f668e2ae 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,22 +1,20 @@ exclude: ^(tests/data|scripts|ftdp/protocols|ftdp/template_configs|ftdp/tool_dicts)/ repos: - repo: https://github.com/PyCQA/flake8 - rev: 7.1.1 + rev: 7.0.0 hooks: - id: flake8 - repo: https://github.com/PyCQA/isort rev: 5.13.2 hooks: - id: isort - args: ["--profile", "black", "--filter-files", "--line-width", "119"] - repo: https://github.com/psf/black - rev: 24.10.0 + rev: 22.8.0 hooks: - id: black args: ["--line-length", "119", "--skip-string-normalization"] - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v4.5.0 hooks: - id: trailing-whitespace - id: check-yaml @@ -29,7 +27,7 @@ repos: - id: mixed-line-ending args: ["--fix=lf"] - repo: https://github.com/executablebooks/mdformat - rev: 0.7.21 + rev: 0.7.17 hooks: - id: mdformat args: ["--number"] @@ -38,11 +36,11 @@ repos: - mdformat_frontmatter - linkify-it-py - repo: https://github.com/codespell-project/codespell - rev: v2.3.0 + rev: v2.2.6 hooks: - id: codespell - repo: https://github.com/asottile/pyupgrade - rev: v3.19.1 + rev: v3.15.0 hooks: - id: pyupgrade args: ["--py36-plus"] diff --git a/app.py b/app.py index a0a6d796e75dc3616b8576ce8de91cd02f6ca8f5..b332056e853f7d9fb3168f1822b2aabbd2e39820 100644 --- a/app.py +++ b/app.py @@ -2,8 +2,8 @@ Author: Highthoughts cht7613@gmail.com Date: 2025-01-30 11:02:01 LastEditors: Highthoughts cht7613@gmail.com -LastEditTime: 2025-01-30 11:02:13 -FilePath: \lagent\app.py +LastEditTime: 2025-01-30 11:41:16 +FilePath: \AgentTest\app.py Description: Copyright (c) 2025 by Cuihaitao, All Rights Reserved. 
diff --git a/examples/agent_api_web_demo.py b/examples/agent_api_web_demo.py index 66ff51bc81bd8885d8bb10cd6463a41c64483790..779fcc6fc537b3e1d06a866cfb0a5cafafc9e3cb 100644 --- a/examples/agent_api_web_demo.py +++ b/examples/agent_api_web_demo.py
@@ -2,7 +2,7 @@ import copy
 import os
 from typing import List
 import streamlit as st
-from lagent.actions import ArxivSearch
+from lagent.actions import ArxivSearch, WeatherQuery
 from lagent.prompts.parsers import PluginParser
 from lagent.agents.stream import INTERPRETER_CN, META_CN, PLUGIN_CN, AgentForInternLM, get_plugin_prompt
 from lagent.llms import GPTAPI
@@ -17,6 +17,7 @@ class SessionState:
         # Initialize the plugin list
         action_list = [
             ArxivSearch(),
+            WeatherQuery(),
         ]
         st.session_state['plugin_map'] = {action.name: action for action in action_list}
         st.session_state['model_map'] = {}  # Stores model instances
@@ -50,7 +51,7 @@ class StreamlitUI:
         #     page_title='lagent-web',
         #     page_icon='./docs/imgs/lagent_icon.png'
         # )
-        # st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
+        st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')

     def setup_sidebar(self):
         """Set up the sidebar: select the model and plugins."""
diff --git a/lagent.egg-info/PKG-INFO b/lagent.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..7e1547f7dc2c1b2a3a8a41b2af0b76fe149c6b5c --- /dev/null +++ b/lagent.egg-info/PKG-INFO @@ -0,0 +1,608 @@ +Metadata-Version: 2.2 +Name: lagent +Version: 0.5.0rc1 +Summary: A lightweight framework for building LLM-based agents +Home-page: https://github.com/InternLM/lagent +License: Apache 2.0 +Keywords: artificial general intelligence,agent,agi,llm +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: aiohttp +Requires-Dist: arxiv +Requires-Dist: asyncache +Requires-Dist: asyncer +Requires-Dist: distro +Requires-Dist: duckduckgo_search==5.3.1b1 +Requires-Dist: filelock +Requires-Dist: func_timeout +Requires-Dist: griffe<1.0 +Requires-Dist: json5 +Requires-Dist: jsonschema +Requires-Dist: jupyter==1.0.0 +Requires-Dist: jupyter_client==8.6.2 +Requires-Dist: jupyter_core==5.7.2 +Requires-Dist: pydantic==2.6.4 +Requires-Dist: requests +Requires-Dist: termcolor +Requires-Dist: tiktoken +Requires-Dist: timeout-decorator +Requires-Dist: typing-extensions +Provides-Extra: all +Requires-Dist: google-search-results; extra == "all" +Requires-Dist: lmdeploy>=0.2.5; extra == "all" +Requires-Dist: pillow; extra == "all" +Requires-Dist: python-pptx; extra == "all" +Requires-Dist: timeout_decorator; extra == "all" +Requires-Dist: torch; extra == "all" +Requires-Dist: transformers<=4.40,>=4.34; extra == "all" +Requires-Dist: vllm>=0.3.3; extra == "all" +Requires-Dist: aiohttp; extra == "all" +Requires-Dist: arxiv; extra == "all" +Requires-Dist: asyncache; extra == "all" +Requires-Dist: asyncer; extra == "all" +Requires-Dist: distro; extra == "all" +Requires-Dist: duckduckgo_search==5.3.1b1; extra == "all" +Requires-Dist: filelock; extra == "all" +Requires-Dist: func_timeout; extra == "all" +Requires-Dist: griffe<1.0; extra == "all" +Requires-Dist: json5; extra == "all" +Requires-Dist: jsonschema; extra == "all" +Requires-Dist: jupyter==1.0.0; extra == "all" +Requires-Dist: jupyter_client==8.6.2; extra == "all" +Requires-Dist: jupyter_core==5.7.2; extra == "all" +Requires-Dist: pydantic==2.6.4; extra == "all" +Requires-Dist: requests; extra == "all" +Requires-Dist: termcolor; extra == "all" +Requires-Dist: tiktoken; extra == "all" +Requires-Dist: timeout-decorator; extra == "all" +Requires-Dist: 
typing-extensions; extra == "all" +Provides-Extra: optional +Requires-Dist: google-search-results; extra == "optional" +Requires-Dist: lmdeploy>=0.2.5; extra == "optional" +Requires-Dist: pillow; extra == "optional" +Requires-Dist: python-pptx; extra == "optional" +Requires-Dist: timeout_decorator; extra == "optional" +Requires-Dist: torch; extra == "optional" +Requires-Dist: transformers<=4.40,>=4.34; extra == "optional" +Requires-Dist: vllm>=0.3.3; extra == "optional" +Dynamic: description +Dynamic: description-content-type +Dynamic: home-page +Dynamic: keywords +Dynamic: license +Dynamic: provides-extra +Dynamic: requires-dist +Dynamic: summary + +
+<div id="top"></div>
+
+<div align="center">
+
+[![docs](https://img.shields.io/badge/docs-latest-blue)](https://lagent.readthedocs.io/en/latest/)
+[![PyPI](https://img.shields.io/pypi/v/lagent)](https://pypi.org/project/lagent)
+[![license](https://img.shields.io/github/license/InternLM/lagent.svg)](https://github.com/InternLM/lagent/tree/main/LICENSE)
+[![issue resolution](https://img.shields.io/github/issues-closed-raw/InternLM/lagent)](https://github.com/InternLM/lagent/issues)
+[![open issues](https://img.shields.io/github/issues-raw/InternLM/lagent)](https://github.com/InternLM/lagent/issues)
+![Visitors](https://api.visitorbadge.io/api/visitors?path=InternLM%2Flagent%20&countColor=%23263759&style=flat)
+![GitHub forks](https://img.shields.io/github/forks/InternLM/lagent)
+![GitHub Repo stars](https://img.shields.io/github/stars/InternLM/lagent)
+![GitHub contributors](https://img.shields.io/github/contributors/InternLM/lagent)
+
+</div>
+
+<p align="center">
+    👋 join us on 𝕏 (Twitter), Discord and WeChat
+</p>
+
+## Installation
+
+Install from source:
+
+```bash
+git clone https://github.com/InternLM/lagent.git
+cd lagent
+pip install -e .
+```
+
+## Usage
+
+Lagent is inspired by the design philosophy of PyTorch. We expect that the analogy of neural network layers will make the workflow clearer and more intuitive, so users only need to focus on creating layers and defining message passing between them in a Pythonic way. This is a simple tutorial to get you quickly started with building multi-agent applications.
+
+### Models as Agents
+
+Agents use `AgentMessage` for communication.
+
+```python
+from typing import Dict, List
+from lagent.agents import Agent
+from lagent.schema import AgentMessage
+from lagent.llms import VllmModel, INTERNLM2_META
+
+llm = VllmModel(
+    path='Qwen/Qwen2-7B-Instruct',
+    meta_template=INTERNLM2_META,
+    tp=1,
+    top_k=1,
+    temperature=1.0,
+    stop_words=['<|im_end|>'],
+    max_new_tokens=1024,
+)
+system_prompt = '你的回答只能从"典"、"孝"、"急"三个字中选一个。'
+agent = Agent(llm, system_prompt)
+
+user_msg = AgentMessage(sender='user', content='今天天气情况')
+bot_msg = agent(user_msg)
+print(bot_msg)
+```
+
+```
+content='急' sender='Agent' formatted=None extra_info=None type=None receiver=None stream_state=<AgentStatusCode.END: 0>
+```
+
+### Memory as State
+
+Both input and output messages will be added to the memory of `Agent` in each forward pass. This is performed in `__call__` rather than `forward`. See the following pseudocode:
+
+```python
+    def __call__(self, *message):
+        message = pre_hooks(message)
+        add_memory(message)
+        message = self.forward(*message)
+        add_memory(message)
+        message = post_hooks(message)
+        return message
+```
+
+Inspect the memory in two ways:
+
+```python
+memory: List[AgentMessage] = agent.memory.get_memory()
+print(memory)
+print('-' * 120)
+dumped_memory: Dict[str, List[dict]] = agent.state_dict()
+print(dumped_memory['memory'])
+```
+
+```
+[AgentMessage(content='今天天气情况', sender='user', formatted=None, extra_info=None, type=None, receiver=None, stream_state=<AgentStatusCode.END: 0>), AgentMessage(content='急', sender='Agent', formatted=None, extra_info=None, type=None, receiver=None, stream_state=<AgentStatusCode.END: 0>)]
+------------------------------------------------------------------------------------------------------------------------
+[{'content': '今天天气情况', 'sender': 'user', 'formatted': None, 'extra_info': None, 'type': None, 'receiver': None, 'stream_state': <AgentStatusCode.END: 0>}, {'content': '急', 'sender': 'Agent', 'formatted': None, 'extra_info': None, 'type': None, 'receiver': None, 'stream_state': <AgentStatusCode.END: 0>}]
+```
+
+Clear the memory of this session (`session_id=0` by default):
+
+```python
+agent.memory.reset()
+```
+
+### Custom Message Aggregation
+
+`DefaultAggregator` is called under the hood to assemble and convert `AgentMessage` to OpenAI message format.
+
+```python
+    def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]:
+        # Aggregate this session's memory into an OpenAI-style message list.
+        formatted_messages = self.aggregator.aggregate(
+            self.memory.get(session_id),
+            self.name,
+            self.output_format,
+            self.template,
+        )
+        llm_response = self.llm.chat(formatted_messages, **kwargs)
+        ...
+```
+
+Implement a simple aggregator that can receive few-shot examples:
+
+```python
+from typing import List, Union
+from lagent.memory import Memory
+from lagent.prompts import StrParser
+from lagent.agents.aggregator import DefaultAggregator
+
+class FewshotAggregator(DefaultAggregator):
+    def __init__(self, few_shot: List[dict] = None):
+        self.few_shot = few_shot or []
+
+    def aggregate(self,
+                  messages: Memory,
+                  name: str,
+                  parser: StrParser = None,
+                  system_instruction: Union[str, dict, List[dict]] = None) -> List[dict]:
+        _message = []
+        if system_instruction:
+            _message.extend(
+                self.aggregate_system_intruction(system_instruction))
+        _message.extend(self.few_shot)
+        messages = messages.get_memory()
+        for message in messages:
+            if message.sender == name:
+                _message.append(
+                    dict(role='assistant', content=str(message.content)))
+            else:
+                user_message = message.content
+                if len(_message) > 0 and _message[-1]['role'] == 'user':
+                    _message[-1]['content'] += user_message
+                else:
+                    _message.append(dict(role='user', content=user_message))
+        return _message
+
+agent = Agent(
+    llm,
+    aggregator=FewshotAggregator(
+        [
+            {"role": "user", "content": "今天天气"},
+            {"role": "assistant", "content": "【晴】"},
+        ]
+    )
+)
+user_msg = AgentMessage(sender='user', content='昨天天气')
+bot_msg = agent(user_msg)
+print(bot_msg)
+```
+
+```
+content='【多云转晴,夜间有轻微降温】' sender='Agent' formatted=None extra_info=None type=None receiver=None stream_state=<AgentStatusCode.END: 0>
+```
+
+### Flexible Response Formatting
+
+In `AgentMessage`, `formatted` is reserved to store information parsed by `output_format` from the model output.
+
+```python
+    def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]:
+        ...
+        llm_response = self.llm.chat(formatted_messages, **kwargs)
+        if self.output_format:
+            formatted_messages = self.output_format.parse_response(llm_response)
+            return AgentMessage(
+                sender=self.name,
+                content=llm_response,
+                formatted=formatted_messages,
+            )
+        ...
+```
+
+Use a tool parser as follows:
+
+````python
+from lagent.prompts.parsers import ToolParser
+
+system_prompt = "逐步分析并编写Python代码解决以下问题。"
+parser = ToolParser(tool_type='code interpreter', begin='```python\n', end='\n```\n')
+llm.gen_params['stop_words'].append('\n```\n')
+agent = Agent(llm, system_prompt, output_format=parser)
+
+user_msg = AgentMessage(
+    sender='user',
+    content='Marie is thinking of a multiple of 63, while Jay is thinking of a '
+    'factor of 63. They happen to be thinking of the same number. There are '
+    'two possibilities for the number that each of them is thinking of, one '
+    'positive and one negative.
Find the product of these two numbers.') +bot_msg = agent(user_msg) +print(bot_msg.model_dump_json(indent=4)) +```` + +```` +{ + "content": "首先,我们需要找出63的所有正因数和负因数。63的正因数可以通过分解63的质因数来找出,即\\(63 = 3^2 \\times 7\\)。因此,63的正因数包括1, 3, 7, 9, 21, 和 63。对于负因数,我们只需将上述正因数乘以-1。\n\n接下来,我们需要找出与63的正因数相乘的结果为63的数,以及与63的负因数相乘的结果为63的数。这可以通过将63除以每个正因数和负因数来实现。\n\n最后,我们将找到的两个数相乘得到最终答案。\n\n下面是Python代码实现:\n\n```python\ndef find_numbers():\n # 正因数\n positive_factors = [1, 3, 7, 9, 21, 63]\n # 负因数\n negative_factors = [-1, -3, -7, -9, -21, -63]\n \n # 找到与正因数相乘的结果为63的数\n positive_numbers = [63 / factor for factor in positive_factors]\n # 找到与负因数相乘的结果为63的数\n negative_numbers = [-63 / factor for factor in negative_factors]\n \n # 计算两个数的乘积\n product = positive_numbers[0] * negative_numbers[0]\n \n return product\n\nresult = find_numbers()\nprint(result)", + "sender": "Agent", + "formatted": { + "tool_type": "code interpreter", + "thought": "首先,我们需要找出63的所有正因数和负因数。63的正因数可以通过分解63的质因数来找出,即\\(63 = 3^2 \\times 7\\)。因此,63的正因数包括1, 3, 7, 9, 21, 和 63。对于负因数,我们只需将上述正因数乘以-1。\n\n接下来,我们需要找出与63的正因数相乘的结果为63的数,以及与63的负因数相乘的结果为63的数。这可以通过将63除以每个正因数和负因数来实现。\n\n最后,我们将找到的两个数相乘得到最终答案。\n\n下面是Python代码实现:\n\n", + "action": "def find_numbers():\n # 正因数\n positive_factors = [1, 3, 7, 9, 21, 63]\n # 负因数\n negative_factors = [-1, -3, -7, -9, -21, -63]\n \n # 找到与正因数相乘的结果为63的数\n positive_numbers = [63 / factor for factor in positive_factors]\n # 找到与负因数相乘的结果为63的数\n negative_numbers = [-63 / factor for factor in negative_factors]\n \n # 计算两个数的乘积\n product = positive_numbers[0] * negative_numbers[0]\n \n return product\n\nresult = find_numbers()\nprint(result)", + "status": 1 + }, + "extra_info": null, + "type": null, + "receiver": null, + "stream_state": 0 +} +```` + +### Consistency of Tool Calling + +`ActionExecutor` uses the same communication data structure as `Agent`, but requires the content of input `AgentMessage` to be a dict containing: + +- `name`: tool name, e.g. `'IPythonInterpreter'`, `'WebBrowser.search'`. +- `parameters`: keyword arguments of the tool API, e.g. `{'command': 'import math;math.sqrt(2)'}`, `{'query': ['recent progress in AI']}`. + +You can register custom hooks for message conversion. 
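+
+For reference, a direct tool invocation without any hooks can be written against this contract alone. The snippet below is a minimal sketch (the tool choice and the command string are illustrative, not part of the original example):
+
+```python
+from lagent.actions import ActionExecutor, IPythonInterpreter
+from lagent.schema import AgentMessage
+
+# The executor looks the tool up by `name` and expands `parameters` as keyword arguments.
+executor = ActionExecutor(actions=[IPythonInterpreter()])
+tool_msg = AgentMessage(
+    sender='Agent',
+    content=dict(name='IPythonInterpreter', parameters={'command': 'import math;print(math.sqrt(2))'}),
+)
+result_msg = executor(tool_msg)
+print(result_msg.content)  # an ActionReturn carrying the execution result
+```
+
+The hooks below then adapt a `ToolParser`-formatted message to this structure and post-process the execution result: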
+
+```python
+from lagent.hooks import Hook
+from lagent.schema import ActionReturn, ActionStatusCode, AgentMessage
+from lagent.actions import ActionExecutor, IPythonInteractive
+
+class CodeProcessor(Hook):
+    def before_action(self, executor, message, session_id):
+        message = message.copy(deep=True)
+        message.content = dict(
+            name='IPythonInteractive', parameters={'command': message.formatted['action']}
+        )
+        return message
+
+    def after_action(self, executor, message, session_id):
+        action_return = message.content
+        if isinstance(action_return, ActionReturn):
+            if action_return.state == ActionStatusCode.SUCCESS:
+                response = action_return.format_result()
+            else:
+                response = action_return.errmsg
+        else:
+            response = action_return
+        message.content = response
+        return message
+
+executor = ActionExecutor(actions=[IPythonInteractive()], hooks=[CodeProcessor()])
+bot_msg = AgentMessage(
+    sender='Agent',
+    content='首先,我们需要...',
+    formatted={
+        'tool_type': 'code interpreter',
+        'thought': '首先,我们需要...',
+        'action': 'def find_numbers():\n    # 正因数\n    positive_factors = [1, 3, 7, 9, 21, 63]\n    # 负因数\n    negative_factors = [-1, -3, -7, -9, -21, -63]\n    \n    # 找到与正因数相乘的结果为63的数\n    positive_numbers = [63 / factor for factor in positive_factors]\n    # 找到与负因数相乘的结果为63的数\n    negative_numbers = [-63 / factor for factor in negative_factors]\n    \n    # 计算两个数的乘积\n    product = positive_numbers[0] * negative_numbers[0]\n    \n    return product\n\nresult = find_numbers()\nprint(result)',
+        'status': 1
+    })
+executor_msg = executor(bot_msg)
+print(executor_msg)
+```
+
+```
+content='3969.0' sender='ActionExecutor' formatted=None extra_info=None type=None receiver=None stream_state=<AgentStatusCode.END: 0>
+```
+
+**For convenience, Lagent provides `InternLMActionProcessor`, which is adapted to messages formatted by `ToolParser` as mentioned above.**
+
+### Dual Interfaces
+
+Lagent adopts a dual-interface design, where almost every component (LLMs, actions, action executors...) has a corresponding asynchronous variant obtained by prefixing its identifier with 'Async'. It is recommended to use synchronous agents for debugging and asynchronous ones for large-scale inference to make the most of idle CPU and GPU resources.
+
+However, ensure the internal consistency of agents: asynchronous agents should be equipped with asynchronous LLMs and asynchronous action executors that drive asynchronous tools.
+
+```python
+from lagent.llms import VllmModel, AsyncVllmModel, LMDeployPipeline, AsyncLMDeployPipeline
+from lagent.actions import ActionExecutor, AsyncActionExecutor, WebBrowser, AsyncWebBrowser
+from lagent.agents import Agent, AsyncAgent, AgentForInternLM, AsyncAgentForInternLM
+```
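+
+As an end-to-end illustration, an asynchronous pipeline mirrors the synchronous one almost line by line. The following is a minimal sketch, assuming a valid OpenAI key and the same `AsyncGPTAPI` settings used later in this README:
+
+```python
+import asyncio
+import os
+from lagent.agents import AsyncAgent
+from lagent.llms import AsyncGPTAPI
+from lagent.schema import AgentMessage
+
+os.environ['OPENAI_API_KEY'] = 'YOUR_API_KEY'
+
+async def main():
+    # An asynchronous agent must be driven by an asynchronous LLM.
+    llm = AsyncGPTAPI(model_type='gpt-4o-2024-05-13', retry=5, max_new_tokens=256)
+    agent = AsyncAgent(llm, 'You are a helpful assistant.')
+    bot_msg = await agent(AgentMessage(sender='user', content='hello'), session_id=0)
+    print(bot_msg.content)
+
+asyncio.run(main())
+```
+
+______________________________________________________________________
+
+## Practice
+
+- **Try to implement `forward` instead of `__call__` of subclasses unless necessary.**
+- **Always include the `session_id` argument explicitly, which is designed for isolation of memory, LLM requests and tool invocation (e.g.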
maintain multiple independent IPython environments) in concurrency.** + +### Single Agent + +Math agents that solve problems by programming + +````python +from lagent.agents.aggregator import InternLMToolAggregator + +class Coder(Agent): + def __init__(self, model_path, system_prompt, max_turn=3): + super().__init__() + llm = VllmModel( + path=model_path, + meta_template=INTERNLM2_META, + tp=1, + top_k=1, + temperature=1.0, + stop_words=['\n```\n', '<|im_end|>'], + max_new_tokens=1024, + ) + self.agent = Agent( + llm, + system_prompt, + output_format=ToolParser( + tool_type='code interpreter', begin='```python\n', end='\n```\n' + ), + # `InternLMToolAggregator` is adapted to `ToolParser` for aggregating + # messages with tool invocations and execution results + aggregator=InternLMToolAggregator(), + ) + self.executor = ActionExecutor([IPythonInteractive()], hooks=[CodeProcessor()]) + self.max_turn = max_turn + + def forward(self, message: AgentMessage, session_id=0) -> AgentMessage: + for _ in range(self.max_turn): + message = self.agent(message, session_id=session_id) + if message.formatted['tool_type'] is None: + return message + message = self.executor(message, session_id=session_id) + return message + +coder = Coder('Qwen/Qwen2-7B-Instruct', 'Solve the problem step by step with assistance of Python code') +query = AgentMessage( + sender='user', + content='Find the projection of $\\mathbf{a}$ onto $\\mathbf{b} = ' + '\\begin{pmatrix} 1 \\\\ -3 \\end{pmatrix}$ if $\\mathbf{a} \\cdot \\mathbf{b} = 2.$' +) +answer = coder(query) +print(answer.content) +print('-' * 120) +for msg in coder.state_dict()['agent.memory']: + print('*' * 80) + print(f'{msg["sender"]}:\n\n{msg["content"]}') +```` + +### Multiple Agents + +Asynchronous blogging agents that improve writing quality by self-refinement ([original AutoGen example](https://microsoft.github.io/autogen/0.2/docs/topics/prompting-and-reasoning/reflection/)) + +```python +import asyncio +import os +from lagent.llms import AsyncGPTAPI +from lagent.agents import AsyncAgent +os.environ['OPENAI_API_KEY'] = 'YOUR_API_KEY' + +class PrefixedMessageHook(Hook): + def __init__(self, prefix: str, senders: list = None): + self.prefix = prefix + self.senders = senders or [] + + def before_agent(self, agent, messages, session_id): + for message in messages: + if message.sender in self.senders: + message.content = self.prefix + message.content + +class AsyncBlogger(AsyncAgent): + def __init__(self, model_path, writer_prompt, critic_prompt, critic_prefix='', max_turn=3): + super().__init__() + llm = AsyncGPTAPI(model_type=model_path, retry=5, max_new_tokens=2048) + self.writer = AsyncAgent(llm, writer_prompt, name='writer') + self.critic = AsyncAgent( + llm, critic_prompt, name='critic', hooks=[PrefixedMessageHook(critic_prefix, ['writer'])] + ) + self.max_turn = max_turn + + async def forward(self, message: AgentMessage, session_id=0) -> AgentMessage: + for _ in range(self.max_turn): + message = await self.writer(message, session_id=session_id) + message = await self.critic(message, session_id=session_id) + return await self.writer(message, session_id=session_id) + +blogger = AsyncBlogger( + 'gpt-4o-2024-05-13', + writer_prompt="You are an writing assistant tasked to write engaging blogpost. You try to generate the best blogpost possible for the user's request. " + "If the user provides critique, then respond with a revised version of your previous attempts", + critic_prompt="Generate critique and recommendations on the writing. 
Provide detailed recommendations, including requests for length, depth, style, etc..", + critic_prefix='Reflect and provide critique on the following writing. \n\n', +) +user_prompt = ( + "Write an engaging blogpost on the recent updates in {topic}. " + "The blogpost should be engaging and understandable for general audience. " + "Should have more than 3 paragraphes but no longer than 1000 words.") +bot_msgs = asyncio.get_event_loop().run_until_complete( + asyncio.gather( + *[ + blogger(AgentMessage(sender='user', content=user_prompt.format(topic=topic)), session_id=i) + for i, topic in enumerate(['AI', 'Biotechnology', 'New Energy', 'Video Games', 'Pop Music']) + ] + ) +) +print(bot_msgs[0].content) +print('-' * 120) +for msg in blogger.state_dict(session_id=0)['writer.memory']: + print('*' * 80) + print(f'{msg["sender"]}:\n\n{msg["content"]}') +print('-' * 120) +for msg in blogger.state_dict(session_id=0)['critic.memory']: + print('*' * 80) + print(f'{msg["sender"]}:\n\n{msg["content"]}') +``` + +A multi-agent workflow that performs information retrieval, data collection and chart plotting ([original LangGraph example](https://vijaykumarkartha.medium.com/multiple-ai-agents-creating-multi-agent-workflows-using-langgraph-and-langchain-0587406ec4e6)) + +
+ +````python +import json +from lagent.actions import IPythonInterpreter, WebBrowser, ActionExecutor +from lagent.agents.stream import get_plugin_prompt +from lagent.llms import GPTAPI +from lagent.hooks import InternLMActionProcessor + +TOOL_TEMPLATE = ( + "You are a helpful AI assistant, collaborating with other assistants. Use the provided tools to progress" + " towards answering the question. If you are unable to fully answer, that's OK, another assistant with" + " different tools will help where you left off. Execute what you can to make progress. If you or any of" + " the other assistants have the final answer or deliverable, prefix your response with {finish_pattern}" + " so the team knows to stop. You have access to the following tools:\n{tool_description}\nPlease provide" + " your thought process when you need to use a tool, followed by the call statement in this format:" + "\n{invocation_format}\\\\n**{system_prompt}**" +) + +class DataVisualizer(Agent): + def __init__(self, model_path, research_prompt, chart_prompt, finish_pattern="Final Answer", max_turn=10): + super().__init__() + llm = GPTAPI(model_path, key='YOUR_OPENAI_API_KEY', retry=5, max_new_tokens=1024, stop_words=["```\n"]) + interpreter, browser = IPythonInterpreter(), WebBrowser("BingSearch", api_key="YOUR_BING_API_KEY") + self.researcher = Agent( + llm, + TOOL_TEMPLATE.format( + finish_pattern=finish_pattern, + tool_description=get_plugin_prompt(browser), + invocation_format='```json\n{"name": {{tool name}}, "parameters": {{keyword arguments}}}\n```\n', + system_prompt=research_prompt, + ), + output_format=ToolParser( + "browser", + begin="```json\n", + end="\n```\n", + validate=lambda x: json.loads(x.rstrip('`')), + ), + aggregator=InternLMToolAggregator(), + name="researcher", + ) + self.charter = Agent( + llm, + TOOL_TEMPLATE.format( + finish_pattern=finish_pattern, + tool_description=interpreter.name, + invocation_format='```python\n{{code}}\n```\n', + system_prompt=chart_prompt, + ), + output_format=ToolParser( + "interpreter", + begin="```python\n", + end="\n```\n", + validate=lambda x: x.rstrip('`'), + ), + aggregator=InternLMToolAggregator(), + name="charter", + ) + self.executor = ActionExecutor([interpreter, browser], hooks=[InternLMActionProcessor()]) + self.finish_pattern = finish_pattern + self.max_turn = max_turn + + def forward(self, message, session_id=0): + for _ in range(self.max_turn): + message = self.researcher(message, session_id=session_id, stop_words=["```\n", "```python"]) # override llm stop words + while message.formatted["tool_type"]: + message = self.executor(message, session_id=session_id) + message = self.researcher(message, session_id=session_id, stop_words=["```\n", "```python"]) + if self.finish_pattern in message.content: + return message + message = self.charter(message) + while message.formatted["tool_type"]: + message = self.executor(message, session_id=session_id) + message = self.charter(message, session_id=session_id) + if self.finish_pattern in message.content: + return message + return message + +visualizer = DataVisualizer( + "gpt-4o-2024-05-13", + research_prompt="You should provide accurate data for the chart generator to use.", + chart_prompt="Any charts you display will be visible by the user.", +) +user_msg = AgentMessage( + sender='user', + content="Fetch the China's GDP over the past 5 years, then draw a line graph of it. 
Once you code it up, finish.")
+bot_msg = visualizer(user_msg)
+print(bot_msg.content)
+json.dump(visualizer.state_dict(), open('visualizer.json', 'w'), ensure_ascii=False, indent=4)
+````
+
+## Citation
+
+If you find this project useful in your research, please consider citing:
+
+```latex
+@misc{lagent2023,
+    title={{Lagent: InternLM} a lightweight open-source framework that allows users to efficiently build large language model(LLM)-based agents},
+    author={Lagent Developer Team},
+    howpublished = {\url{https://github.com/InternLM/lagent}},
+    year={2023}
+}
+```
+
+## License
+
+This project is released under the [Apache 2.0 license](LICENSE).
+

🔼 Back to top

diff --git a/lagent.egg-info/SOURCES.txt b/lagent.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..7c2d07daed3cc5f3d3c344a618941479237907eb --- /dev/null +++ b/lagent.egg-info/SOURCES.txt @@ -0,0 +1,71 @@ +LICENSE +MANIFEST.in +README.md +setup.cfg +setup.py +lagent/__init__.py +lagent/schema.py +lagent/version.py +lagent.egg-info/PKG-INFO +lagent.egg-info/SOURCES.txt +lagent.egg-info/dependency_links.txt +lagent.egg-info/requires.txt +lagent.egg-info/top_level.txt +lagent/actions/__init__.py +lagent/actions/action_executor.py +lagent/actions/arxiv_search.py +lagent/actions/base_action.py +lagent/actions/bing_map.py +lagent/actions/builtin_actions.py +lagent/actions/google_scholar_search.py +lagent/actions/google_search.py +lagent/actions/ipython_interactive.py +lagent/actions/ipython_interpreter.py +lagent/actions/ipython_manager.py +lagent/actions/parser.py +lagent/actions/ppt.py +lagent/actions/python_interpreter.py +lagent/actions/web_browser.py +lagent/agents/__init__.py +lagent/agents/agent.py +lagent/agents/react.py +lagent/agents/stream.py +lagent/agents/aggregator/__init__.py +lagent/agents/aggregator/default_aggregator.py +lagent/agents/aggregator/tool_aggregator.py +lagent/distributed/__init__.py +lagent/distributed/http_serve/__init__.py +lagent/distributed/http_serve/api_server.py +lagent/distributed/http_serve/app.py +lagent/distributed/ray_serve/__init__.py +lagent/distributed/ray_serve/ray_warpper.py +lagent/hooks/__init__.py +lagent/hooks/action_preprocessor.py +lagent/hooks/hook.py +lagent/hooks/logger.py +lagent/llms/__init__.py +lagent/llms/base_api.py +lagent/llms/base_llm.py +lagent/llms/huggingface.py +lagent/llms/lmdeploy_wrapper.py +lagent/llms/meta_template.py +lagent/llms/openai.py +lagent/llms/sensenova.py +lagent/llms/vllm_wrapper.py +lagent/memory/__init__.py +lagent/memory/base_memory.py +lagent/memory/manager.py +lagent/prompts/__init__.py +lagent/prompts/prompt_template.py +lagent/prompts/parsers/__init__.py +lagent/prompts/parsers/custom_parser.py +lagent/prompts/parsers/json_parser.py +lagent/prompts/parsers/str_parser.py +lagent/prompts/parsers/tool_parser.py +lagent/utils/__init__.py +lagent/utils/gen_key.py +lagent/utils/package.py +lagent/utils/util.py +requirements/docs.txt +requirements/optional.txt +requirements/runtime.txt \ No newline at end of file diff --git a/lagent.egg-info/dependency_links.txt b/lagent.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/lagent.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/lagent.egg-info/requires.txt b/lagent.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..cfd987460eac9660713576d83741d17b3022d711 --- /dev/null +++ b/lagent.egg-info/requires.txt @@ -0,0 +1,59 @@ +aiohttp +arxiv +asyncache +asyncer +distro +duckduckgo_search==5.3.1b1 +filelock +func_timeout +griffe<1.0 +json5 +jsonschema +jupyter==1.0.0 +jupyter_client==8.6.2 +jupyter_core==5.7.2 +pydantic==2.6.4 +requests +termcolor +tiktoken +timeout-decorator +typing-extensions + +[all] +google-search-results +lmdeploy>=0.2.5 +pillow +python-pptx +timeout_decorator +torch +transformers<=4.40,>=4.34 +vllm>=0.3.3 +aiohttp +arxiv +asyncache +asyncer +distro +duckduckgo_search==5.3.1b1 +filelock +func_timeout +griffe<1.0 +json5 +jsonschema +jupyter==1.0.0 +jupyter_client==8.6.2 +jupyter_core==5.7.2 +pydantic==2.6.4 +requests +termcolor +tiktoken 
+typing-extensions
+
+[optional]
+google-search-results
+lmdeploy>=0.2.5
+pillow
+python-pptx
+timeout_decorator
+torch
+transformers<=4.40,>=4.34
+vllm>=0.3.3
diff --git a/lagent.egg-info/top_level.txt b/lagent.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..9273dc63a1927785084010f14533cbb6197c40d7 --- /dev/null +++ b/lagent.egg-info/top_level.txt @@ -0,0 +1 @@ +lagent
diff --git a/lagent/__pycache__/__init__.cpython-310.pyc b/lagent/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab8bb2ca3b25fe54e8a1d311ddad0f28de2649db Binary files /dev/null and b/lagent/__pycache__/__init__.cpython-310.pyc differ
diff --git a/lagent/__pycache__/schema.cpython-310.pyc b/lagent/__pycache__/schema.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f56e32f0fefd5d903a8dbba131969c7bc14a2c4f Binary files /dev/null and b/lagent/__pycache__/schema.cpython-310.pyc differ
diff --git a/lagent/__pycache__/version.cpython-310.pyc b/lagent/__pycache__/version.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1515fac2badaad54607c7814d572d7b432f1389e Binary files /dev/null and b/lagent/__pycache__/version.cpython-310.pyc differ
diff --git a/lagent/actions/__init__.py b/lagent/actions/__init__.py index b75a226295716828f8424133db1e9f1f6e622d64..2de8c5fb2dbd349a1c9a157a05e1e6406b54e08a 100644 --- a/lagent/actions/__init__.py +++ b/lagent/actions/__init__.py
@@ -1,6 +1,6 @@
 from .action_executor import ActionExecutor, AsyncActionExecutor
 from .arxiv_search import ArxivSearch, AsyncArxivSearch
-from .base_action import AsyncActionMixin, BaseAction, tool_api
+from .base_action import BaseAction, tool_api
 from .bing_map import AsyncBINGMap, BINGMap
 from .builtin_actions import FinishAction, InvalidAction, NoAction
 from .google_scholar_search import AsyncGoogleScholar, GoogleScholar
@@ -14,34 +14,13 @@ from .python_interpreter import AsyncPythonInterpreter, PythonInterpreter
 from .web_browser import AsyncWebBrowser, WebBrowser
+from .weather_query import WeatherQuery
 
 __all__ = [
-    'BaseAction',
-    'ActionExecutor',
-    'AsyncActionExecutor',
-    'InvalidAction',
-    'FinishAction',
-    'NoAction',
-    'BINGMap',
-    'AsyncBINGMap',
-    'ArxivSearch',
-    'AsyncArxivSearch',
-    'GoogleSearch',
-    'AsyncGoogleSearch',
-    'GoogleScholar',
-    'AsyncGoogleScholar',
-    'IPythonInterpreter',
-    'AsyncIPythonInterpreter',
-    'IPythonInteractive',
-    'AsyncIPythonInteractive',
-    'IPythonInteractiveManager',
-    'PythonInterpreter',
-    'AsyncPythonInterpreter',
-    'PPT',
-    'AsyncPPT',
-    'WebBrowser',
-    'AsyncWebBrowser',
-    'BaseParser',
-    'JsonParser',
-    'TupleParser',
-    'tool_api',
-    'AsyncActionMixin',
+    'BaseAction', 'ActionExecutor', 'AsyncActionExecutor', 'InvalidAction',
+    'FinishAction', 'NoAction', 'BINGMap', 'AsyncBINGMap', 'ArxivSearch',
+    'AsyncArxivSearch', 'GoogleSearch', 'AsyncGoogleSearch', 'GoogleScholar',
+    'AsyncGoogleScholar', 'IPythonInterpreter', 'AsyncIPythonInterpreter',
+    'IPythonInteractive', 'AsyncIPythonInteractive',
+    'IPythonInteractiveManager', 'PythonInterpreter', 'AsyncPythonInterpreter',
+    'PPT', 'AsyncPPT', 'WebBrowser', 'AsyncWebBrowser', 'BaseParser',
+    'JsonParser', 'TupleParser', 'tool_api', 'WeatherQuery'  # register the new tool
 ]
diff --git a/lagent/actions/__pycache__/__init__.cpython-310.pyc b/lagent/actions/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc12c912731722388203a08f96065f616b36f470 Binary files /dev/null and b/lagent/actions/__pycache__/__init__.cpython-310.pyc differ
diff --git a/lagent/actions/__pycache__/action_executor.cpython-310.pyc b/lagent/actions/__pycache__/action_executor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c34b06993a346fbd659671b4ffff135305e03497 Binary files /dev/null and b/lagent/actions/__pycache__/action_executor.cpython-310.pyc differ
diff --git a/lagent/actions/__pycache__/arxiv_search.cpython-310.pyc b/lagent/actions/__pycache__/arxiv_search.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5279b31cc8727162a43b4a57862bca718e82be13 Binary files /dev/null and b/lagent/actions/__pycache__/arxiv_search.cpython-310.pyc differ
diff --git a/lagent/actions/__pycache__/base_action.cpython-310.pyc b/lagent/actions/__pycache__/base_action.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9503b7b3795df5ef503ea6d5cf9d5bdb5d63cc59 Binary files /dev/null and b/lagent/actions/__pycache__/base_action.cpython-310.pyc differ
diff --git a/lagent/actions/__pycache__/bing_map.cpython-310.pyc b/lagent/actions/__pycache__/bing_map.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c349a45d606d111a04f5971324a7dbe03f73bad Binary files /dev/null and b/lagent/actions/__pycache__/bing_map.cpython-310.pyc differ
diff --git a/lagent/actions/__pycache__/builtin_actions.cpython-310.pyc b/lagent/actions/__pycache__/builtin_actions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cb08aa0d91d1e454bbeb99a723176f792d2253a Binary files /dev/null and b/lagent/actions/__pycache__/builtin_actions.cpython-310.pyc differ
diff --git a/lagent/actions/__pycache__/google_scholar_search.cpython-310.pyc b/lagent/actions/__pycache__/google_scholar_search.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c29d76d88ea54a4d1e980c96d7cf9275243b91d5 Binary files /dev/null and b/lagent/actions/__pycache__/google_scholar_search.cpython-310.pyc differ
diff --git a/lagent/actions/__pycache__/google_search.cpython-310.pyc b/lagent/actions/__pycache__/google_search.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b92ee94abffa9a30a5ddb1a632e7e79eccd58c19 Binary files /dev/null and b/lagent/actions/__pycache__/google_search.cpython-310.pyc differ
diff --git a/lagent/actions/__pycache__/ipython_interactive.cpython-310.pyc b/lagent/actions/__pycache__/ipython_interactive.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a936a478e2ba5526c4c523c7f9f550388b41930 Binary files /dev/null and b/lagent/actions/__pycache__/ipython_interactive.cpython-310.pyc differ
diff --git a/lagent/actions/__pycache__/ipython_interpreter.cpython-310.pyc b/lagent/actions/__pycache__/ipython_interpreter.cpython-310.pyc new file mode 100644 index
0000000000000000000000000000000000000000..d1742c8424d721d249a1f4deb29e060c7bab928a Binary files /dev/null and b/lagent/actions/__pycache__/ipython_interpreter.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/ipython_manager.cpython-310.pyc b/lagent/actions/__pycache__/ipython_manager.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d081556c32b1c9e30e7b86ac28543258ea522db1 Binary files /dev/null and b/lagent/actions/__pycache__/ipython_manager.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/parser.cpython-310.pyc b/lagent/actions/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3117dcae5c085b33f2f9c55b030574a11757cdc1 Binary files /dev/null and b/lagent/actions/__pycache__/parser.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/ppt.cpython-310.pyc b/lagent/actions/__pycache__/ppt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9477630b9b91bae41d1ce0605e18de64b2bc498a Binary files /dev/null and b/lagent/actions/__pycache__/ppt.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/python_interpreter.cpython-310.pyc b/lagent/actions/__pycache__/python_interpreter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0062b3c86a16040847df1c841a6acd5516ad57d Binary files /dev/null and b/lagent/actions/__pycache__/python_interpreter.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/weather_query.cpython-310.pyc b/lagent/actions/__pycache__/weather_query.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53bfbf6c9a5dc4760a2071da5ebd42023827113f Binary files /dev/null and b/lagent/actions/__pycache__/weather_query.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/web_browser.cpython-310.pyc b/lagent/actions/__pycache__/web_browser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f18cf11d7d3d0615aacfd298b42eee1982515ea9 Binary files /dev/null and b/lagent/actions/__pycache__/web_browser.cpython-310.pyc differ diff --git a/lagent/actions/base_action.py b/lagent/actions/base_action.py index b42036af012376dbd2623458a49245cb40b65baf..2e546f75bb251c73991fee124c00a222f53bbce5 100644 --- a/lagent/actions/base_action.py +++ b/lagent/actions/base_action.py @@ -4,7 +4,7 @@ import re from abc import ABCMeta from copy import deepcopy from functools import wraps -from typing import Callable, Iterable, Optional, Type, get_args, get_origin +from typing import Callable, Optional, Type, get_args, get_origin try: from typing import Annotated @@ -24,15 +24,11 @@ from .parser import BaseParser, JsonParser, ParseError logging.getLogger('griffe').setLevel(logging.ERROR) -def tool_api( - func: Optional[Callable] = None, - *, - explode_return: bool = False, - returns_named_value: bool = False, - include_arguments: Optional[Iterable[str]] = None, - exclude_arguments: Optional[Iterable[str]] = None, - **kwargs, -): +def tool_api(func: Optional[Callable] = None, + *, + explode_return: bool = False, + returns_named_value: bool = False, + **kwargs): """Turn functions into tools. It will parse typehints as well as docstrings to build the tool description and attach it to functions via an attribute ``api_description``. @@ -94,16 +90,6 @@ def tool_api( ``return_data`` field will be added to ``api_description`` only when ``explode_return`` or ``returns_named_value`` is enabled. 
""" - if include_arguments is None: - exclude_arguments = exclude_arguments or set() - if isinstance(exclude_arguments, str): - exclude_arguments = {exclude_arguments} - elif not isinstance(exclude_arguments, set): - exclude_arguments = set(exclude_arguments) - if 'self' not in exclude_arguments: - exclude_arguments.add('self') - else: - include_arguments = {include_arguments} if isinstance(include_arguments, str) else set(include_arguments) def _detect_type(string): field_type = 'STRING' @@ -120,9 +106,10 @@ def tool_api( def _explode(desc): kvs = [] - desc = '\nArgs:\n' + '\n'.join( - [' ' + item.lstrip(' -+*#.') for item in desc.split('\n')[1:] if item.strip()] - ) + desc = '\nArgs:\n' + '\n'.join([ + ' ' + item.lstrip(' -+*#.') + for item in desc.split('\n')[1:] if item.strip() + ]) docs = Docstring(desc).parse('google') if not docs: return kvs @@ -138,12 +125,13 @@ def tool_api( def _parse_tool(function): # remove rst syntax - docs = Docstring(re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse( - 'google', returns_named_value=returns_named_value, **kwargs - ) + docs = Docstring( + re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse( + 'google', returns_named_value=returns_named_value, **kwargs) desc = dict( name=function.__name__, - description=docs[0].value if docs[0].kind is DocstringSectionKind.text else '', + description=docs[0].value + if docs[0].kind is DocstringSectionKind.text else '', parameters=[], required=[], ) @@ -167,14 +155,17 @@ def tool_api( sig = inspect.signature(function) for name, param in sig.parameters.items(): - if name in exclude_arguments if include_arguments is None else name not in include_arguments: + if name == 'self': continue parameter = dict( - name=param.name, type='STRING', description=args_doc.get(param.name, {}).get('description', '') - ) + name=param.name, + type='STRING', + description=args_doc.get(param.name, + {}).get('description', '')) annotation = param.annotation if annotation is inspect.Signature.empty: - parameter['type'] = args_doc.get(param.name, {}).get('type', 'STRING') + parameter['type'] = args_doc.get(param.name, + {}).get('type', 'STRING') else: if get_origin(annotation) is Annotated: annotation, info = get_args(annotation) @@ -238,8 +229,9 @@ class ToolMeta(ABCMeta): def __new__(mcs, name, base, attrs): is_toolkit, tool_desc = True, dict( - name=name, description=Docstring(attrs.get('__doc__', '')).parse('google')[0].value - ) + name=name, + description=Docstring(attrs.get('__doc__', + '')).parse('google')[0].value) for key, value in attrs.items(): if callable(value) and hasattr(value, 'api_description'): api_desc = getattr(value, 'api_description') @@ -254,7 +246,8 @@ class ToolMeta(ABCMeta): else: tool_desc.setdefault('api_list', []).append(api_desc) if not is_toolkit and 'api_list' in tool_desc: - raise KeyError('`run` and other tool APIs can not be implemented ' 'at the same time') + raise KeyError('`run` and other tool APIs can not be implemented ' + 'at the same time') if is_toolkit and 'api_list' not in tool_desc: is_toolkit = False if callable(attrs.get('run')): @@ -353,16 +346,26 @@ class BaseAction(metaclass=ToolMeta): fallback_args = {'inputs': inputs, 'name': name} if not hasattr(self, name): return ActionReturn( - fallback_args, type=self.name, errmsg=f'invalid API: {name}', state=ActionStatusCode.API_ERROR - ) + fallback_args, + type=self.name, + errmsg=f'invalid API: {name}', + state=ActionStatusCode.API_ERROR) try: inputs = self._parser.parse_inputs(inputs, name) except ParseError as exc: - 
return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR)
+            return ActionReturn(
+                fallback_args,
+                type=self.name,
+                errmsg=exc.err_msg,
+                state=ActionStatusCode.ARGS_ERROR)
         try:
             outputs = getattr(self, name)(**inputs)
         except Exception as exc:
-            return ActionReturn(inputs, type=self.name, errmsg=str(exc), state=ActionStatusCode.API_ERROR)
+            return ActionReturn(
+                inputs,
+                type=self.name,
+                errmsg=str(exc),
+                state=ActionStatusCode.API_ERROR)
         if isinstance(outputs, ActionReturn):
             action_return = outputs
             if not action_return.args:
@@ -399,16 +402,26 @@ class AsyncActionMixin:
         fallback_args = {'inputs': inputs, 'name': name}
         if not hasattr(self, name):
             return ActionReturn(
-                fallback_args, type=self.name, errmsg=f'invalid API: {name}', state=ActionStatusCode.API_ERROR
-            )
+                fallback_args,
+                type=self.name,
+                errmsg=f'invalid API: {name}',
+                state=ActionStatusCode.API_ERROR)
         try:
             inputs = self._parser.parse_inputs(inputs, name)
         except ParseError as exc:
-            return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR)
+            return ActionReturn(
+                fallback_args,
+                type=self.name,
+                errmsg=exc.err_msg,
+                state=ActionStatusCode.ARGS_ERROR)
         try:
             outputs = await getattr(self, name)(**inputs)
         except Exception as exc:
-            return ActionReturn(inputs, type=self.name, errmsg=str(exc), state=ActionStatusCode.API_ERROR)
+            return ActionReturn(
+                inputs,
+                type=self.name,
+                errmsg=str(exc),
+                state=ActionStatusCode.API_ERROR)
         if isinstance(outputs, ActionReturn):
             action_return = outputs
             if not action_return.args:
diff --git a/lagent/actions/weather_query.py b/lagent/actions/weather_query.py new file mode 100644 index 0000000000000000000000000000000000000000..dbe3e991dbca34e0a6d373d62d457c7237317741 --- /dev/null +++ b/lagent/actions/weather_query.py
@@ -0,0 +1,70 @@
+import os
+import requests
+from lagent.actions.base_action import BaseAction, tool_api
+from lagent.schema import ActionReturn, ActionStatusCode
+
+class WeatherQuery(BaseAction):
+    def __init__(self):
+        super().__init__()
+        self.api_key = os.getenv("weather_token")
+        if not self.api_key:
+            raise EnvironmentError(
+                "Environment variable 'weather_token' not found. Please set your QWeather API key, "
+                "e.g. export weather_token='xxx'")
+
+    @tool_api
+    def run(self, location: str) -> dict:
+        """Query real-time weather information.
+
+        Args:
+            location (str): The place to query: a city name, a LocationID, or longitude,latitude coordinates (e.g. "101010100" or "116.41,39.92").
+
+        Returns:
+            dict: A dict of weather information
+                * location: location name
+                * weather: weather condition
+                * temperature: current temperature
+                * wind_direction: wind direction
+                * wind_speed: wind speed (km/h)
+                * humidity: relative humidity (%)
+                * report_time: time of the weather report
+        """
+        try:
+            # If `location` is not in coordinate form (e.g. "116.41,39.92"), resolve it to a LocationID via GeoAPI.
+            if not ("," in location and location.replace(",", "").replace(".", "").isdigit()):
+                geo_url = f"https://geoapi.qweather.com/v2/city/lookup?location={location}&key={self.api_key}"
+                geo_response = requests.get(geo_url)
+                geo_data = geo_response.json()
+
+                if geo_data.get("code") != "200" or not geo_data.get("location"):
+                    raise Exception(f"GeoAPI returned error code {geo_data.get('code')} or no matching location")
+
+                location = geo_data["location"][0]["id"]
+
+            # Build the real-time weather request URL.
+            weather_url = f"https://devapi.qweather.com/v7/weather/now?location={location}&key={self.api_key}"
+            response = requests.get(weather_url)
+            data = response.json()
+
+            # Check the API response code.
+            if data.get("code") != "200":
+                raise Exception(f"Weather API returned error code {data.get('code')}")
+
+            # Assemble the weather information.
+            weather_info = {
+                "location": location,
+                "weather": data["now"]["text"],
+                "temperature": data["now"]["temp"] + "°C",
+                "wind_direction": data["now"]["windDir"],
+                "wind_speed": data["now"]["windSpeed"] + " km/h",
+                "humidity": data["now"]["humidity"] + "%",
+                "report_time": data["updateTime"]
+            }
+
+            return {"result": weather_info}
+
+        except Exception as exc:
+            return ActionReturn(
+                errmsg=f"WeatherQuery error: {exc}",
+                state=ActionStatusCode.HTTP_ERROR
+            )
diff --git a/lagent/actions/web_browser.py b/lagent/actions/web_browser.py index 29f7594c62d534553321e1fe70d917ee9c34a593..432fe27eee69175220aea506b415278ed97ea767 100644 --- a/lagent/actions/web_browser.py +++ b/lagent/actions/web_browser.py
@@ -18,6 +18,7 @@ import requests
 from asyncache import cached as acached
 from bs4 import BeautifulSoup
 from cachetools import TTLCache, cached
+from duckduckgo_search import DDGS, AsyncDDGS
 
 from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
 from lagent.actions.parser import BaseParser, JsonParser
@@ -34,11 +35,12 @@ class BaseSearch:
         filtered_results = {}
         count = 0
         for url, snippet, title in results:
-            if all(domain not in url for domain in self.black_list) and not url.endswith('.pdf'):
+            if all(domain not in url
+                   for domain in self.black_list) and not url.endswith('.pdf'):
                 filtered_results[count] = {
                     'url': url,
                     'summ': json.dumps(snippet, ensure_ascii=False)[1:-1],
-                    'title': title,
+                    'title': title
                 }
                 count += 1
                 if count >= self.topk:
@@ -48,17 +50,15 @@ class BaseSearch:
 
 class DuckDuckGoSearch(BaseSearch):
 
-    def __init__(
-        self,
-        topk: int = 3,
-        black_list: List[str] = [
-            'enoN',
-            'youtube.com',
-            'bilibili.com',
-            'researchgate.net',
-        ],
-        **kwargs,
-    ):
+    def __init__(self,
+                 topk: int = 3,
+                 black_list: List[str] = [
+                     'enoN',
+                     'youtube.com',
+                     'bilibili.com',
+                     'researchgate.net',
+                 ],
+                 **kwargs):
         self.proxy = kwargs.get('proxy')
         self.timeout = kwargs.get('timeout', 30)
         super().__init__(topk, black_list)
@@ -67,39 +67,40 @@ class DuckDuckGoSearch(BaseSearch):
     def search(self, query: str, max_retry: int = 3) -> dict:
         for attempt in range(max_retry):
             try:
-                response = self._call_ddgs(query, timeout=self.timeout, proxy=self.proxy)
+                response = self._call_ddgs(
+                    query, timeout=self.timeout, proxy=self.proxy)
                 return self._parse_response(response)
             except Exception as e:
                 logging.exception(str(e))
-                warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}')
+                warnings.warn(
+                    f'Retry {attempt + 1}/{max_retry} due to error: {e}')
                 time.sleep(random.randint(2, 5))
-        raise Exception('Failed to get search results from DuckDuckGo after retries.')
+        raise Exception(
+            'Failed to get search results from DuckDuckGo after retries.')
 
     @acached(cache=TTLCache(maxsize=100, ttl=600))
     async def asearch(self, query: str, max_retry: int = 3) -> dict:
-        from duckduckgo_search import AsyncDDGS
-
         for attempt in range(max_retry):
             try:
                 ddgs = AsyncDDGS(timeout=self.timeout, proxy=self.proxy)
-                response = await ddgs.text(query.strip("'"), max_results=10)
+                response = await ddgs.atext(query.strip("'"), max_results=10)
                 return self._parse_response(response)
             except Exception as e:
                 if isinstance(e, asyncio.TimeoutError):
                     logging.exception('Request to DDGS timed out.')
                 logging.exception(str(e))
-                warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}')
+                warnings.warn(
+                    f'Retry {attempt + 1}/{max_retry} due to error: {e}')
                 await asyncio.sleep(random.randint(2, 5))
-        raise Exception('Failed to get search results from DuckDuckGo after retries.')
+        raise Exception(
+            'Failed to get search results 
from DuckDuckGo after retries.') async def _async_call_ddgs(self, query: str, **kwargs) -> dict: - from duckduckgo_search import DDGS - ddgs = DDGS(**kwargs) try: response = await asyncio.wait_for( - asyncio.to_thread(ddgs.text, query.strip("'"), max_results=10), timeout=self.timeout - ) + asyncio.to_thread(ddgs.text, query.strip("'"), max_results=10), + timeout=self.timeout) return response except asyncio.TimeoutError: logging.exception('Request to DDGS timed out.') @@ -109,35 +110,34 @@ class DuckDuckGoSearch(BaseSearch): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - response = loop.run_until_complete(self._async_call_ddgs(query, **kwargs)) + response = loop.run_until_complete( + self._async_call_ddgs(query, **kwargs)) return response finally: loop.close() - def _parse_response(self, response: List[dict]) -> dict: + def _parse_response(self, response: dict) -> dict: raw_results = [] for item in response: raw_results.append( - (item['href'], item['description'] if 'description' in item else item['body'], item['title']) - ) + (item['href'], item['description'] + if 'description' in item else item['body'], item['title'])) return self._filter_results(raw_results) class BingSearch(BaseSearch): - def __init__( - self, - api_key: str, - region: str = 'zh-CN', - topk: int = 3, - black_list: List[str] = [ - 'enoN', - 'youtube.com', - 'bilibili.com', - 'researchgate.net', - ], - **kwargs, - ): + def __init__(self, + api_key: str, + region: str = 'zh-CN', + topk: int = 3, + black_list: List[str] = [ + 'enoN', + 'youtube.com', + 'bilibili.com', + 'researchgate.net', + ], + **kwargs): self.api_key = api_key self.market = region self.proxy = kwargs.get('proxy') @@ -151,9 +151,11 @@ class BingSearch(BaseSearch): return self._parse_response(response) except Exception as e: logging.exception(str(e)) - warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}') + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') time.sleep(random.randint(2, 5)) - raise Exception('Failed to get search results from Bing Search after retries.') + raise Exception( + 'Failed to get search results from Bing Search after retries.') @acached(cache=TTLCache(maxsize=100, ttl=600)) async def asearch(self, query: str, max_retry: int = 3) -> dict: @@ -163,15 +165,18 @@ class BingSearch(BaseSearch): return self._parse_response(response) except Exception as e: logging.exception(str(e)) - warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}') + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') await asyncio.sleep(random.randint(2, 5)) - raise Exception('Failed to get search results from Bing Search after retries.') + raise Exception( + 'Failed to get search results from Bing Search after retries.') def _call_bing_api(self, query: str) -> dict: endpoint = 'https://api.bing.microsoft.com/v7.0/search' params = {'q': query, 'mkt': self.market, 'count': f'{self.topk * 2}'} headers = {'Ocp-Apim-Subscription-Key': self.api_key} - response = requests.get(endpoint, headers=headers, params=params, proxies=self.proxy) + response = requests.get( + endpoint, headers=headers, params=params, proxies=self.proxy) response.raise_for_status() return response.json() @@ -181,25 +186,32 @@ class BingSearch(BaseSearch): headers = {'Ocp-Apim-Subscription-Key': self.api_key} async with aiohttp.ClientSession(raise_for_status=True) as session: async with session.get( - endpoint, - headers=headers, - params=params, - proxy=self.proxy and (self.proxy.get('http') or 
self.proxy.get('https')), - ) as resp: + endpoint, + headers=headers, + params=params, + proxy=self.proxy and + (self.proxy.get('http') or self.proxy.get('https'))) as resp: return await resp.json() def _parse_response(self, response: dict) -> dict: - webpages = {w['id']: w for w in response.get('webPages', {}).get('value', [])} + webpages = { + w['id']: w + for w in response.get('webPages', {}).get('value', []) + } raw_results = [] - for item in response.get('rankingResponse', {}).get('mainline', {}).get('items', []): + for item in response.get('rankingResponse', + {}).get('mainline', {}).get('items', []): if item['answerType'] == 'WebPages': webpage = webpages.get(item['value']['id']) if webpage: - raw_results.append((webpage['url'], webpage['snippet'], webpage['name'])) - elif item['answerType'] == 'News' and item['value']['id'] == response.get('news', {}).get('id'): + raw_results.append( + (webpage['url'], webpage['snippet'], webpage['name'])) + elif item['answerType'] == 'News' and item['value'][ + 'id'] == response.get('news', {}).get('id'): for news in response.get('news', {}).get('value', []): - raw_results.append((news['url'], news['description'], news['name'])) + raw_results.append( + (news['url'], news['description'], news['name'])) return self._filter_results(raw_results) @@ -218,27 +230,24 @@ class BraveSearch(BaseSearch): topk (int): The number of search results returned in response from API search results. region (str): The country code string. Specifies the country where the search results come from. language (str): The language code string. Specifies the preferred language for the search results. - extra_snippets (bool): Allows retrieving up to 5 additional snippets, which are alternative excerpts from the - search results. + extra_snippets (bool): Allows retrieving up to 5 additional snippets, which are alternative excerpts from the search results. **kwargs: Any other parameters related to the Brave Search API. Find more details at https://api.search.brave.com/app/documentation/web-search/get-started. 
""" - def __init__( - self, - api_key: str, - region: str = 'ALL', - language: str = 'zh-hans', - extra_snippests: bool = True, - topk: int = 3, - black_list: List[str] = [ - 'enoN', - 'youtube.com', - 'bilibili.com', - 'researchgate.net', - ], - **kwargs, - ): + def __init__(self, + api_key: str, + region: str = 'ALL', + language: str = 'zh-hans', + extra_snippests: bool = True, + topk: int = 3, + black_list: List[str] = [ + 'enoN', + 'youtube.com', + 'bilibili.com', + 'researchgate.net', + ], + **kwargs): self.api_key = api_key self.market = region self.proxy = kwargs.get('proxy') @@ -256,9 +265,11 @@ class BraveSearch(BaseSearch): return self._parse_response(response) except Exception as e: logging.exception(str(e)) - warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}') + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') time.sleep(random.randint(2, 5)) - raise Exception('Failed to get search results from Brave Search after retries.') + raise Exception( + 'Failed to get search results from Brave Search after retries.') @acached(cache=TTLCache(maxsize=100, ttl=600)) async def asearch(self, query: str, max_retry: int = 3) -> dict: @@ -268,9 +279,11 @@ class BraveSearch(BaseSearch): return self._parse_response(response) except Exception as e: logging.exception(str(e)) - warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}') + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') await asyncio.sleep(random.randint(2, 5)) - raise Exception('Failed to get search results from Brave Search after retries.') + raise Exception( + 'Failed to get search results from Brave Search after retries.') def _call_brave_api(self, query: str) -> dict: endpoint = f'https://api.search.brave.com/res/v1/{self.search_type}/search' @@ -280,10 +293,17 @@ class BraveSearch(BaseSearch): 'search_lang': self.language, 'extra_snippets': self.extra_snippests, 'count': self.topk, - **{key: value for key, value in self.kwargs.items() if value is not None}, + **{ + key: value + for key, value in self.kwargs.items() if value is not None + }, } - headers = {'X-Subscription-Token': self.api_key or '', 'Accept': 'application/json'} - response = requests.get(endpoint, headers=headers, params=params, proxies=self.proxy) + headers = { + 'X-Subscription-Token': self.api_key or '', + 'Accept': 'application/json' + } + response = requests.get( + endpoint, headers=headers, params=params, proxies=self.proxy) response.raise_for_status() return response.json() @@ -295,16 +315,22 @@ class BraveSearch(BaseSearch): 'search_lang': self.language, 'extra_snippets': self.extra_snippests, 'count': self.topk, - **{key: value for key, value in self.kwargs.items() if value is not None}, + **{ + key: value + for key, value in self.kwargs.items() if value is not None + }, + } + headers = { + 'X-Subscription-Token': self.api_key or '', + 'Accept': 'application/json' } - headers = {'X-Subscription-Token': self.api_key or '', 'Accept': 'application/json'} async with aiohttp.ClientSession(raise_for_status=True) as session: async with session.get( - endpoint, - headers=headers, - params=params, - proxy=self.proxy and (self.proxy.get('http') or self.proxy.get('https')), - ) as resp: + endpoint, + headers=headers, + params=params, + proxy=self.proxy and + (self.proxy.get('http') or self.proxy.get('https'))) as resp: return await resp.json() def _parse_response(self, response: dict) -> dict: @@ -315,13 +341,15 @@ class BraveSearch(BaseSearch): raw_results = [] for item in filtered_result: - 
raw_results.append( - ( - item.get('url', ''), - ' '.join(filter(None, [item.get('description'), *item.get('extra_snippets', [])])), - item.get('title', ''), - ) - ) + raw_results.append(( + item.get('url', ''), + ' '.join( + filter(None, [ + item.get('description'), + *item.get('extra_snippets', []) + ])), + item.get('title', ''), + )) return self._filter_results(raw_results) @@ -348,18 +376,16 @@ class GoogleSearch(BaseSearch): 'search': 'organic', } - def __init__( - self, - api_key: str, - topk: int = 3, - black_list: List[str] = [ - 'enoN', - 'youtube.com', - 'bilibili.com', - 'researchgate.net', - ], - **kwargs, - ): + def __init__(self, + api_key: str, + topk: int = 3, + black_list: List[str] = [ + 'enoN', + 'youtube.com', + 'bilibili.com', + 'researchgate.net', + ], + **kwargs): self.api_key = api_key self.proxy = kwargs.get('proxy') self.search_type = kwargs.get('search_type', 'search') @@ -374,9 +400,12 @@ class GoogleSearch(BaseSearch): return self._parse_response(response) except Exception as e: logging.exception(str(e)) - warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}') + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') time.sleep(random.randint(2, 5)) - raise Exception('Failed to get search results from Google Serper Search after retries.') + raise Exception( + 'Failed to get search results from Google Serper Search after retries.' + ) @acached(cache=TTLCache(maxsize=100, ttl=600)) async def asearch(self, query: str, max_retry: int = 3) -> dict: @@ -386,19 +415,29 @@ class GoogleSearch(BaseSearch): return self._parse_response(response) except Exception as e: logging.exception(str(e)) - warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}') + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') await asyncio.sleep(random.randint(2, 5)) - raise Exception('Failed to get search results from Google Serper Search after retries.') + raise Exception( + 'Failed to get search results from Google Serper Search after retries.' 
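Every searcher in this module repeats the same retry skeleton in both search and asearch: attempt the API call, log the exception, warn with the attempt count, back off 2-5 seconds, and raise once max_retry is exhausted. Condensed into a standalone helper (the name call_with_retry and its signature are illustrative, not part of lagent):

import logging
import random
import time
import warnings

def call_with_retry(fn, *args, max_retry: int = 3, **kwargs):
    # Mirrors the search/asearch retry loops above: log, warn, back off, give up.
    for attempt in range(max_retry):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            logging.exception(str(e))
            warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}')
            time.sleep(random.randint(2, 5))  # crude jittered backoff
    raise Exception('Failed to get search results after retries.')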
+ ) def _call_serper_api(self, query: str) -> dict: endpoint = f'https://google.serper.dev/{self.search_type}' params = { 'q': query, 'num': self.topk, - **{key: value for key, value in self.kwargs.items() if value is not None}, + **{ + key: value + for key, value in self.kwargs.items() if value is not None + }, } - headers = {'X-API-KEY': self.api_key or '', 'Content-Type': 'application/json'} - response = requests.get(endpoint, headers=headers, params=params, proxies=self.proxy) + headers = { + 'X-API-KEY': self.api_key or '', + 'Content-Type': 'application/json' + } + response = requests.get( + endpoint, headers=headers, params=params, proxies=self.proxy) response.raise_for_status() return response.json() @@ -407,16 +446,22 @@ class GoogleSearch(BaseSearch): params = { 'q': query, 'num': self.topk, - **{key: value for key, value in self.kwargs.items() if value is not None}, + **{ + key: value + for key, value in self.kwargs.items() if value is not None + }, + } + headers = { + 'X-API-KEY': self.api_key or '', + 'Content-Type': 'application/json' } - headers = {'X-API-KEY': self.api_key or '', 'Content-Type': 'application/json'} async with aiohttp.ClientSession(raise_for_status=True) as session: async with session.get( - endpoint, - headers=headers, - params=params, - proxy=self.proxy and (self.proxy.get('http') or self.proxy.get('https')), - ) as resp: + endpoint, + headers=headers, + params=params, + proxy=self.proxy and + (self.proxy.get('http') or self.proxy.get('https'))) as resp: return await resp.json() def _parse_response(self, response: dict) -> dict: @@ -427,34 +472,33 @@ class GoogleSearch(BaseSearch): if answer_box.get('answer'): raw_results.append(('', answer_box.get('answer'), '')) elif answer_box.get('snippet'): - raw_results.append(('', answer_box.get('snippet').replace('\n', ' '), '')) + raw_results.append( + ('', answer_box.get('snippet').replace('\n', ' '), '')) elif answer_box.get('snippetHighlighted'): - raw_results.append(('', answer_box.get('snippetHighlighted'), '')) + raw_results.append( + ('', answer_box.get('snippetHighlighted'), '')) if response.get('knowledgeGraph'): kg = response.get('knowledgeGraph', {}) description = kg.get('description', '') - attributes = '. '.join(f'{attribute}: {value}' for attribute, value in kg.get('attributes', {}).items()) + attributes = '. '.join( + f'{attribute}: {value}' + for attribute, value in kg.get('attributes', {}).items()) raw_results.append( - ( - kg.get('descriptionLink', ''), - f'{description}. {attributes}' if attributes else description, - f"{kg.get('title', '')}: {kg.get('type', '')}.", - ) - ) - - for result in response[self.result_key_for_type[self.search_type]][: self.topk]: + (kg.get('descriptionLink', ''), + f'{description}. {attributes}' if attributes else description, + f"{kg.get('title', '')}: {kg.get('type', '')}.")) + + for result in response[self.result_key_for_type[ + self.search_type]][:self.topk]: description = result.get('snippet', '') attributes = '. '.join( - f'{attribute}: {value}' for attribute, value in result.get('attributes', {}).items() - ) + f'{attribute}: {value}' + for attribute, value in result.get('attributes', {}).items()) raw_results.append( - ( - result.get('link', ''), - f'{description}. {attributes}' if attributes else description, - result.get('title', ''), - ) - ) + (result.get('link', ''), + f'{description}. 
{attributes}' if attributes else description, + result.get('title', ''))) return self._filter_results(raw_results) @@ -485,27 +529,25 @@ class TencentSearch(BaseSearch): Supports multiple values separated by commas. Example: `30010255`. """ - def __init__( - self, - secret_id: str = 'Your SecretId', - secret_key: str = 'Your SecretKey', - api_key: str = '', - action: str = 'SearchCommon', - version: str = '2020-12-29', - service: str = 'tms', - host: str = 'tms.tencentcloudapi.com', - topk: int = 3, - tsn: int = None, - insite: str = None, - category: str = None, - vrid: str = None, - black_list: List[str] = [ - 'enoN', - 'youtube.com', - 'bilibili.com', - 'researchgate.net', - ], - ): + def __init__(self, + secret_id: str = 'Your SecretId', + secret_key: str = 'Your SecretKey', + api_key: str = '', + action: str = 'SearchCommon', + version: str = '2020-12-29', + service: str = 'tms', + host: str = 'tms.tencentcloudapi.com', + topk: int = 3, + tsn: int = None, + insite: str = None, + category: str = None, + vrid: str = None, + black_list: List[str] = [ + 'enoN', + 'youtube.com', + 'bilibili.com', + 'researchgate.net', + ]): self.secret_id = secret_id self.secret_key = secret_key self.api_key = api_key @@ -527,9 +569,11 @@ class TencentSearch(BaseSearch): return self._parse_response(response) except Exception as e: logging.exception(str(e)) - warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}') + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') time.sleep(random.randint(2, 5)) - raise Exception('Failed to get search results from Bing Search after retries.') + raise Exception( + 'Failed to get search results from Bing Search after retries.') @acached(cache=TTLCache(maxsize=100, ttl=600)) async def asearch(self, query: str, max_retry: int = 3) -> dict: @@ -539,9 +583,11 @@ class TencentSearch(BaseSearch): return self._parse_response(response) except Exception as e: logging.exception(str(e)) - warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}') + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') await asyncio.sleep(random.randint(2, 5)) - raise Exception('Failed to get search results from Bing Search after retries.') + raise Exception( + 'Failed to get search results from Bing Search after retries.') def _get_headers_and_payload(self, query: str) -> tuple: @@ -571,47 +617,33 @@ class TencentSearch(BaseSearch): ct = 'application/json; charset=utf-8' canonical_headers = f'content-type:{ct}\nhost:{self.host}\nx-tc-action:{self.action.lower()}\n' signed_headers = 'content-type;host;x-tc-action' - hashed_request_payload = hashlib.sha256(payload.encode('utf-8')).hexdigest() + hashed_request_payload = hashlib.sha256( + payload.encode('utf-8')).hexdigest() canonical_request = ( - http_request_method - + '\n' - + canonical_uri - + '\n' - + canonical_querystring - + '\n' - + canonical_headers - + '\n' - + signed_headers - + '\n' - + hashed_request_payload - ) + http_request_method + '\n' + canonical_uri + '\n' + + canonical_querystring + '\n' + canonical_headers + '\n' + + signed_headers + '\n' + hashed_request_payload) # ************* 步骤 2:拼接待签名字符串 ************* credential_scope = date + '/' + self.service + '/' + 'tc3_request' - hashed_canonical_request = hashlib.sha256(canonical_request.encode('utf-8')).hexdigest() - string_to_sign = algorithm + '\n' + str(timestamp) + '\n' + credential_scope + '\n' + hashed_canonical_request + hashed_canonical_request = hashlib.sha256( + canonical_request.encode('utf-8')).hexdigest() + 
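# TC3-HMAC-SHA256, step 2: the string to sign joins the algorithm name, the
# Unix timestamp, the credential scope (date/service/tc3_request) and the
# SHA-256 hex digest of the canonical request, one field per line.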
string_to_sign = ( + algorithm + '\n' + str(timestamp) + '\n' + credential_scope + + '\n' + hashed_canonical_request) # ************* 步骤 3:计算签名 ************* secret_date = sign(('TC3' + self.secret_key).encode('utf-8'), date) secret_service = sign(secret_date, self.service) secret_signing = sign(secret_service, 'tc3_request') - signature = hmac.new(secret_signing, string_to_sign.encode('utf-8'), hashlib.sha256).hexdigest() + signature = hmac.new(secret_signing, string_to_sign.encode('utf-8'), + hashlib.sha256).hexdigest() # ************* 步骤 4:拼接 Authorization ************* authorization = ( - algorithm - + ' ' - + 'Credential=' - + self.secret_id - + '/' - + credential_scope - + ', ' - + 'SignedHeaders=' - + signed_headers - + ', ' - + 'Signature=' - + signature - ) + algorithm + ' ' + 'Credential=' + self.secret_id + '/' + + credential_scope + ', ' + 'SignedHeaders=' + signed_headers + + ', ' + 'Signature=' + signature) # ************* 步骤 5:构造并发起请求 ************* headers = { @@ -620,7 +652,7 @@ class TencentSearch(BaseSearch): 'Host': self.host, 'X-TC-Action': self.action, 'X-TC-Timestamp': str(timestamp), - 'X-TC-Version': self.version, + 'X-TC-Version': self.version } # if self.region: # headers["X-TC-Region"] = self.region @@ -638,14 +670,16 @@ class TencentSearch(BaseSearch): except Exception as e: logging.warning(str(e)) import ast - resp = ast.literal_eval(resp) return resp.get('Response', dict()) async def _async_call_tencent_api(self, query: str): headers, payload = self._get_headers_and_payload(query) async with aiohttp.ClientSession(raise_for_status=True) as session: - async with session.post('https://' + self.host.lstrip('/'), headers=headers, data=payload) as resp: + async with session.post( + 'https://' + self.host.lstrip('/'), + headers=headers, + data=payload) as resp: return (await resp.json()).get('Response', {}) def _parse_response(self, response: dict) -> dict: @@ -654,7 +688,8 @@ class TencentSearch(BaseSearch): display = json.loads(item['Display']) if not display['url']: continue - raw_results.append((display['url'], display['content'] or display['abstract_info'], display['title'])) + raw_results.append((display['url'], display['content'] + or display['abstract_info'], display['title'])) return self._filter_results(raw_results) @@ -680,8 +715,8 @@ class ContentFetcher: async def afetch(self, url: str) -> Tuple[bool, str]: try: async with aiohttp.ClientSession( - raise_for_status=True, timeout=aiohttp.ClientTimeout(self.timeout) - ) as session: + raise_for_status=True, + timeout=aiohttp.ClientTimeout(self.timeout)) as session: async with session.get(url) as resp: html = await resp.text(errors='ignore') text = BeautifulSoup(html, 'html.parser').get_text() @@ -692,24 +727,24 @@ class ContentFetcher: class WebBrowser(BaseAction): - """Wrapper around the Web Browser Tool.""" - - def __init__( - self, - searcher_type: str = 'DuckDuckGoSearch', - timeout: int = 5, - black_list: Optional[List[str]] = [ - 'enoN', - 'youtube.com', - 'bilibili.com', - 'researchgate.net', - ], - topk: int = 20, - description: Optional[dict] = None, - parser: Type[BaseParser] = JsonParser, - **kwargs, - ): - self.searcher = eval(searcher_type)(black_list=black_list, topk=topk, **kwargs) + """Wrapper around the Web Browser Tool. 
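The four signing steps above condense into a short standalone function. A sketch, assuming `sign` is the usual TC3 HMAC-SHA256 helper returning a raw digest (its definition is not shown in this hunk):

import hashlib
import hmac

def tc3_signature(secret_key: str, date: str, service: str,
                  string_to_sign: str) -> str:
    # Key-derivation chain from the TencentSearch code:
    # 'TC3' + secret_key -> date -> service -> 'tc3_request' -> signature.
    def sign(key: bytes, msg: str) -> bytes:
        return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()
    secret_date = sign(('TC3' + secret_key).encode('utf-8'), date)
    secret_service = sign(secret_date, service)
    secret_signing = sign(secret_service, 'tc3_request')
    return hmac.new(secret_signing, string_to_sign.encode('utf-8'),
                    hashlib.sha256).hexdigest()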
+ """ + + def __init__(self, + searcher_type: str = 'DuckDuckGoSearch', + timeout: int = 5, + black_list: Optional[List[str]] = [ + 'enoN', + 'youtube.com', + 'bilibili.com', + 'researchgate.net', + ], + topk: int = 20, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + **kwargs): + self.searcher = eval(searcher_type)( + black_list=black_list, topk=topk, **kwargs) self.fetcher = ContentFetcher(timeout=timeout) self.search_results = None super().__init__(description, parser) @@ -724,7 +759,10 @@ class WebBrowser(BaseAction): search_results = {} with ThreadPoolExecutor() as executor: - future_to_query = {executor.submit(self.searcher.search, q): q for q in queries} + future_to_query = { + executor.submit(self.searcher.search, q): q + for q in queries + } for future in as_completed(future_to_query): query = future_to_query[future] @@ -737,9 +775,13 @@ class WebBrowser(BaseAction): if result['url'] not in search_results: search_results[result['url']] = result else: - search_results[result['url']]['summ'] += f"\n{result['summ']}" + search_results[ + result['url']]['summ'] += f"\n{result['summ']}" - self.search_results = {idx: result for idx, result in enumerate(search_results.values())} + self.search_results = { + idx: result + for idx, result in enumerate(search_results.values()) + } return self.search_results @tool_api @@ -756,8 +798,7 @@ class WebBrowser(BaseAction): with ThreadPoolExecutor() as executor: future_to_id = { executor.submit(self.fetcher.fetch, self.search_results[select_id]['url']): select_id - for select_id in select_ids - if select_id in self.search_results + for select_id in select_ids if select_id in self.search_results } for future in as_completed(future_to_id): select_id = future_to_id[future] @@ -767,8 +808,10 @@ class WebBrowser(BaseAction): warnings.warn(f'{select_id} generated an exception: {exc}') else: if web_success: - self.search_results[select_id]['content'] = web_content[:8192] - new_search_results[select_id] = self.search_results[select_id].copy() + self.search_results[select_id][ + 'content'] = web_content[:8192] + new_search_results[select_id] = self.search_results[ + select_id].copy() new_search_results[select_id].pop('summ') return new_search_results @@ -784,12 +827,13 @@ class WebBrowser(BaseAction): class AsyncWebBrowser(AsyncActionMixin, WebBrowser): - """Wrapper around the Web Browser Tool.""" + """Wrapper around the Web Browser Tool. 
+ """ @tool_api async def search(self, query: Union[str, List[str]]) -> dict: """BING search API - + Args: query (List[str]): list of search query strings """ @@ -812,9 +856,13 @@ class AsyncWebBrowser(AsyncActionMixin, WebBrowser): if result['url'] not in search_results: search_results[result['url']] = result else: - search_results[result['url']]['summ'] += f"\n{result['summ']}" + search_results[ + result['url']]['summ'] += f"\n{result['summ']}" - self.search_results = {idx: result for idx, result in enumerate(search_results.values())} + self.search_results = { + idx: result + for idx, result in enumerate(search_results.values()) + } return self.search_results @tool_api @@ -831,7 +879,8 @@ class AsyncWebBrowser(AsyncActionMixin, WebBrowser): tasks = [] for select_id in select_ids: if select_id in self.search_results: - task = asyncio.create_task(self.fetcher.afetch(self.search_results[select_id]['url'])) + task = asyncio.create_task( + self.fetcher.afetch(self.search_results[select_id]['url'])) task.select_id = select_id tasks.append(task) async for future in async_as_completed(tasks): @@ -842,8 +891,10 @@ class AsyncWebBrowser(AsyncActionMixin, WebBrowser): warnings.warn(f'{select_id} generated an exception: {exc}') else: if web_success: - self.search_results[select_id]['content'] = web_content[:8192] - new_search_results[select_id] = self.search_results[select_id].copy() + self.search_results[select_id][ + 'content'] = web_content[:8192] + new_search_results[select_id] = self.search_results[ + select_id].copy() new_search_results[select_id].pop('summ') return new_search_results diff --git a/lagent/agents/__init__.py b/lagent/agents/__init__.py index 0a995d2883ffc5a2eca5eb26076def19f13137f7..f06972cc56e13012e8fe54a9fe8764748ae93f43 100644 --- a/lagent/agents/__init__.py +++ b/lagent/agents/__init__.py @@ -1,33 +1,9 @@ -from .agent import ( - Agent, - AgentDict, - AgentList, - AsyncAgent, - AsyncSequential, - AsyncStreamingAgent, - AsyncStreamingSequential, - Sequential, - StreamingAgent, - StreamingSequential, -) +from .agent import Agent, AgentDict, AgentList, AsyncAgent, AsyncSequential, Sequential from .react import AsyncReAct, ReAct from .stream import AgentForInternLM, AsyncAgentForInternLM, AsyncMathCoder, MathCoder __all__ = [ - 'Agent', - 'AgentDict', - 'AgentList', - 'AsyncAgent', - 'AgentForInternLM', - 'AsyncAgentForInternLM', - 'MathCoder', - 'AsyncMathCoder', - 'ReAct', - 'AsyncReAct', - 'Sequential', - 'AsyncSequential', - 'StreamingAgent', - 'StreamingSequential', - 'AsyncStreamingAgent', - 'AsyncStreamingSequential', + 'Agent', 'AgentDict', 'AgentList', 'AsyncAgent', 'AgentForInternLM', + 'AsyncAgentForInternLM', 'MathCoder', 'AsyncMathCoder', 'ReAct', + 'AsyncReAct', 'Sequential', 'AsyncSequential' ] diff --git a/lagent/agents/__pycache__/__init__.cpython-310.pyc b/lagent/agents/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c391174cc454ebffdcc718f5c49f018e727afa3 Binary files /dev/null and b/lagent/agents/__pycache__/__init__.cpython-310.pyc differ diff --git a/lagent/agents/__pycache__/agent.cpython-310.pyc b/lagent/agents/__pycache__/agent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60470c6d1a7df89a740f9dac4bf7526cc872e42c Binary files /dev/null and b/lagent/agents/__pycache__/agent.cpython-310.pyc differ diff --git a/lagent/agents/__pycache__/react.cpython-310.pyc b/lagent/agents/__pycache__/react.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..f3d94a7e2d405e012d9a4f0bfd534b425b0c5daa Binary files /dev/null and b/lagent/agents/__pycache__/react.cpython-310.pyc differ diff --git a/lagent/agents/__pycache__/stream.cpython-310.pyc b/lagent/agents/__pycache__/stream.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9560dad57605f1f9b25c10954452243f5b455526 Binary files /dev/null and b/lagent/agents/__pycache__/stream.cpython-310.pyc differ diff --git a/lagent/agents/agent.py b/lagent/agents/agent.py index 9707d7bf98d7eb00e8e5cf4fed2854e984107a30..b1e941baa442a52deb37755f64002724316bcf08 100644 --- a/lagent/agents/agent.py +++ b/lagent/agents/agent.py @@ -3,7 +3,7 @@ import warnings from collections import OrderedDict, UserDict, UserList, abc from functools import wraps from itertools import chain, repeat -from typing import Any, AsyncGenerator, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union from lagent.agents.aggregator import DefaultAggregator from lagent.hooks import Hook, RemovableHandle @@ -11,7 +11,7 @@ from lagent.llms import BaseLLM from lagent.memory import Memory, MemoryManager from lagent.prompts.parsers import StrParser from lagent.prompts.prompt_template import PromptTemplate -from lagent.schema import AgentMessage, ModelStatusCode +from lagent.schema import AgentMessage from lagent.utils import create_object @@ -63,17 +63,29 @@ class Agent: if self.memory: self.memory.add(message, session_id=session_id) - def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessage: + def __call__( + self, + *message: Union[str, AgentMessage, List[AgentMessage]], + session_id=0, + **kwargs, + ) -> AgentMessage: # message.receiver = self.name - message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message] + message = [ + AgentMessage(sender='user', content=m) + if isinstance(m, str) else copy.deepcopy(m) for m in message + ] for hook in self._hooks.values(): result = hook.before_agent(self, message, session_id) if result: message = result self.update_memory(message, session_id=session_id) - response_message = self.forward(*message, session_id=session_id, **kwargs) + response_message = self.forward( + *message, session_id=session_id, **kwargs) if not isinstance(response_message, AgentMessage): - response_message = AgentMessage(sender=self.name, content=response_message) + response_message = AgentMessage( + sender=self.name, + content=response_message, + ) self.update_memory(response_message, session_id=session_id) response_message = copy.deepcopy(response_message) for hook in self._hooks.values(): @@ -82,14 +94,25 @@ class Agent: response_message = result return response_message - def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]: + def forward(self, + *message: AgentMessage, + session_id=0, + **kwargs) -> Union[AgentMessage, str]: formatted_messages = self.aggregator.aggregate( - self.memory.get(session_id), self.name, self.output_format, self.template + self.memory.get(session_id), + self.name, + self.output_format, + self.template, ) llm_response = self.llm.chat(formatted_messages, **kwargs) if self.output_format: - formatted_messages = self.output_format.parse_response(llm_response) - return AgentMessage(sender=self.name, content=llm_response, formatted=formatted_messages) + formatted_messages = self.output_format.parse_response( + 
llm_response) + return AgentMessage( + sender=self.name, + content=llm_response, + formatted=formatted_messages, + ) return llm_response def __setattr__(self, __name: str, __value: Any) -> None: @@ -142,8 +165,12 @@ class Agent: self._hooks[handle.id] = hook return handle - def reset(self, session_id=0, keypath: Optional[str] = None, recursive: bool = False): - assert not (keypath and recursive), 'keypath and recursive can\'t be used together' + def reset(self, + session_id=0, + keypath: Optional[str] = None, + recursive: bool = False): + assert not (keypath and + recursive), 'keypath and recursive can\'t be used together' if keypath: keys, agent = keypath.split('.'), self for key in keys: @@ -162,13 +189,15 @@ class Agent: def __repr__(self): def _rcsv_repr(agent, n_indent=1): - res = agent.__class__.__name__ + (f"(name='{agent.name}')" if agent.name else '') + res = agent.__class__.__name__ + (f"(name='{agent.name}')" + if agent.name else '') modules = [ f"{n_indent * ' '}({name}): {_rcsv_repr(agent, n_indent + 1)}" for name, agent in getattr(agent, '_agents', {}).items() ] if modules: - res += '(\n' + '\n'.join(modules) + f'\n{(n_indent - 1) * " "})' + res += '(\n' + '\n'.join( + modules) + f'\n{(n_indent - 1) * " "})' elif not res.endswith(')'): res += '()' return res @@ -176,18 +205,28 @@ class Agent: return _rcsv_repr(self) -class AsyncAgentMixin: +class AsyncAgent(Agent): - async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessage: - message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message] + async def __call__(self, + *message: AgentMessage | List[AgentMessage], + session_id=0, + **kwargs) -> AgentMessage: + message = [ + AgentMessage(sender='user', content=m) + if isinstance(m, str) else copy.deepcopy(m) for m in message + ] for hook in self._hooks.values(): result = hook.before_agent(self, message, session_id) if result: message = result self.update_memory(message, session_id=session_id) - response_message = await self.forward(*message, session_id=session_id, **kwargs) + response_message = await self.forward( + *message, session_id=session_id, **kwargs) if not isinstance(response_message, AgentMessage): - response_message = AgentMessage(sender=self.name, content=response_message) + response_message = AgentMessage( + sender=self.name, + content=response_message, + ) self.update_memory(response_message, session_id=session_id) response_message = copy.deepcopy(response_message) for hook in self._hooks.values(): @@ -196,133 +235,40 @@ class AsyncAgentMixin: response_message = result return response_message - async def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]: + async def forward(self, + *message: AgentMessage, + session_id=0, + **kwargs) -> Union[AgentMessage, str]: formatted_messages = self.aggregator.aggregate( - self.memory.get(session_id), self.name, self.output_format, self.template + self.memory.get(session_id), + self.name, + self.output_format, + self.template, ) - llm_response = await self.llm.chat(formatted_messages, session_id, **kwargs) + llm_response = await self.llm.chat(formatted_messages, session_id, + **kwargs) if self.output_format: - formatted_messages = self.output_format.parse_response(llm_response) - return AgentMessage(sender=self.name, content=llm_response, formatted=formatted_messages) - return llm_response - - -class AsyncAgent(AsyncAgentMixin, Agent): - """Asynchronous variant of the Agent class""" - - pass - - -class 
StreamingAgentMixin: - """Component that makes agent calling output a streaming response.""" - - def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> Generator[AgentMessage, None, None]: - message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message] - for hook in self._hooks.values(): - result = hook.before_agent(self, message, session_id) - if result: - message = result - self.update_memory(message, session_id=session_id) - response_message = AgentMessage(sender=self.name, content="") - for response_message in self.forward(*message, session_id=session_id, **kwargs): - if not isinstance(response_message, AgentMessage): - model_state, response = response_message - response_message = AgentMessage(sender=self.name, content=response, stream_state=model_state) - yield response_message.model_copy() - self.update_memory(response_message, session_id=session_id) - response_message = copy.deepcopy(response_message) - for hook in self._hooks.values(): - result = hook.after_agent(self, response_message, session_id) - if result: - response_message = result - yield response_message - - def forward( - self, *message: AgentMessage, session_id=0, **kwargs - ) -> Generator[Union[AgentMessage, Tuple[ModelStatusCode, str]], None, None]: - formatted_messages = self.aggregator.aggregate( - self.memory.get(session_id), self.name, self.output_format, self.template - ) - for model_state, response, *_ in self.llm.stream_chat(formatted_messages, session_id=session_id, **kwargs): - yield ( - AgentMessage( - sender=self.name, - content=response, - formatted=self.output_format.parse_response(response), - stream_state=model_state, - ) - if self.output_format - else (model_state, response) - ) - - -class AsyncStreamingAgentMixin: - """Component that makes asynchronous agent calling output a streaming response.""" - - async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AsyncGenerator[AgentMessage, None]: - message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message] - for hook in self._hooks.values(): - result = hook.before_agent(self, message, session_id) - if result: - message = result - self.update_memory(message, session_id=session_id) - response_message = AgentMessage(sender=self.name, content="") - async for response_message in self.forward(*message, session_id=session_id, **kwargs): - if not isinstance(response_message, AgentMessage): - model_state, response = response_message - response_message = AgentMessage(sender=self.name, content=response, stream_state=model_state) - yield response_message.model_copy() - self.update_memory(response_message, session_id=session_id) - response_message = copy.deepcopy(response_message) - for hook in self._hooks.values(): - result = hook.after_agent(self, response_message, session_id) - if result: - response_message = result - yield response_message - - async def forward( - self, *message: AgentMessage, session_id=0, **kwargs - ) -> AsyncGenerator[Union[AgentMessage, Tuple[ModelStatusCode, str]], None]: - formatted_messages = self.aggregator.aggregate( - self.memory.get(session_id), self.name, self.output_format, self.template - ) - async for model_state, response, *_ in self.llm.stream_chat( - formatted_messages, session_id=session_id, **kwargs - ): - yield ( - AgentMessage( - sender=self.name, - content=response, - formatted=self.output_format.parse_response(response), - stream_state=model_state, - ) - if self.output_format - else 
(model_state, response) + formatted_messages = self.output_format.parse_response( + llm_response) + return AgentMessage( + sender=self.name, + content=llm_response, + formatted=formatted_messages, ) - - -class StreamingAgent(StreamingAgentMixin, Agent): - """Streaming variant of the Agent class""" - - pass - - -class AsyncStreamingAgent(AsyncStreamingAgentMixin, Agent): - """Streaming variant of the AsyncAgent class""" - - pass + return llm_response class Sequential(Agent): - """Sequential is an agent container that forwards messages to each agent + """Sequential is an agent container that forwards messages to each agent in the order they are added.""" - def __init__(self, *agents: Union[Agent, Iterable], **kwargs): + def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs): super().__init__(**kwargs) self._agents = OrderedDict() if not agents: raise ValueError('At least one agent should be provided') - if isinstance(agents[0], Iterable) and not isinstance(agents[0], Agent): + if isinstance(agents[0], + Iterable) and not isinstance(agents[0], Agent): if not agents[0]: raise ValueError('At least one agent should be provided') agents = agents[0] @@ -333,11 +279,17 @@ class Sequential(Agent): key, agent = agent self.add_agent(key, agent) - def add_agent(self, name: str, agent: Agent): - assert isinstance(agent, Agent), f'{type(agent)} is not an Agent subclass' + def add_agent(self, name: str, agent: Union[Agent, AsyncAgent]): + assert isinstance( + agent, (Agent, AsyncAgent + )), f'{type(agent)} is not an Agent or AsyncAgent subclass' self._agents[str(name)] = agent - def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs) -> AgentMessage: + def forward(self, + *message: AgentMessage, + session_id=0, + exit_at: Optional[int] = None, + **kwargs) -> AgentMessage: assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0' if exit_at is None: exit_at = len(self) - 1 @@ -345,7 +297,7 @@ class Sequential(Agent): for _ in range(exit_at + 1): agent = next(iterator) if isinstance(message, AgentMessage): - message = (message,) + message = (message, ) message = agent(*message, session_id=session_id, **kwargs) return message @@ -359,11 +311,13 @@ class Sequential(Agent): return len(self._agents) -class AsyncSequential(AsyncAgentMixin, Sequential): +class AsyncSequential(Sequential, AsyncAgent): - async def forward( - self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs - ) -> AgentMessage: + async def forward(self, + *message: AgentMessage, + session_id=0, + exit_at: Optional[int] = None, + **kwargs) -> AgentMessage: assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0' if exit_at is None: exit_at = len(self) - 1 @@ -371,43 +325,11 @@ class AsyncSequential(AsyncAgentMixin, Sequential): for _ in range(exit_at + 1): agent = next(iterator) if isinstance(message, AgentMessage): - message = (message,) + message = (message, ) message = await agent(*message, session_id=session_id, **kwargs) return message -class StreamingSequential(StreamingAgentMixin, Sequential): - """Streaming variant of the Sequential class""" - - def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs): - assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0' - if exit_at is None: - exit_at = len(self) - 1 - iterator = chain.from_iterable(repeat(self._agents.values())) - for _ in range(exit_at + 1): - agent = 
next(iterator) - if isinstance(message, AgentMessage): - message = (message,) - for message in agent(*message, session_id=session_id, **kwargs): - yield message - - -class AsyncStreamingSequential(AsyncStreamingAgentMixin, Sequential): - """Streaming variant of the AsyncSequential class""" - - async def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs): - assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0' - if exit_at is None: - exit_at = len(self) - 1 - iterator = chain.from_iterable(repeat(self._agents.values())) - for _ in range(exit_at + 1): - agent = next(iterator) - if isinstance(message, AgentMessage): - message = (message,) - async for message in agent(*message, session_id=session_id, **kwargs): - yield message - - class AgentContainerMixin: def __init_subclass__(cls): @@ -427,28 +349,33 @@ class AgentContainerMixin: ret = func(self, *args, **kwargs) agents = OrderedDict() - for k, item in self.data.items() if isinstance(self.data, abc.Mapping) else enumerate(self.data): - if isinstance(self.data, abc.Mapping) and not isinstance(k, str): + for k, item in (self.data.items() if isinstance( + self.data, abc.Mapping) else enumerate(self.data)): + if isinstance(self.data, + abc.Mapping) and not isinstance(k, str): _backup(data) - raise KeyError(f'agent name should be a string, got {type(k)}') + raise KeyError( + f'agent name should be a string, got {type(k)}') if isinstance(k, str) and '.' in k: _backup(data) - raise KeyError(f'agent name can\'t contain ".", got {k}') - if not isinstance(item, Agent): + raise KeyError( + f'agent name can\'t contain ".", got {k}') + if not isinstance(item, (Agent, AsyncAgent)): _backup(data) - raise TypeError(f'{type(item)} is not an Agent subclass') + raise TypeError( + f'{type(item)} is not an Agent or AsyncAgent subclass' + ) agents[str(k)] = item self._agents = agents return ret return wrapped_func - # fmt: off for method in [ - 'append', 'sort', 'reverse', 'pop', 'clear', 'update', - 'insert', 'extend', 'remove', '__init__', '__setitem__', - '__delitem__', '__add__', '__iadd__', '__radd__', '__mul__', - '__imul__', '__rmul__' + 'append', 'sort', 'reverse', 'pop', 'clear', 'update', + 'insert', 'extend', 'remove', '__init__', '__setitem__', + '__delitem__', '__add__', '__iadd__', '__radd__', '__mul__', + '__imul__', '__rmul__' ]: if hasattr(cls, method): setattr(cls, method, wrap_api(getattr(cls, method))) @@ -456,7 +383,8 @@ class AgentContainerMixin: class AgentList(Agent, UserList, AgentContainerMixin): - def __init__(self, agents: Optional[Iterable[Agent]] = None): + def __init__(self, + agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None): Agent.__init__(self, memory=None) UserList.__init__(self, agents) self.name = None @@ -464,7 +392,9 @@ class AgentList(Agent, UserList, AgentContainerMixin): class AgentDict(Agent, UserDict, AgentContainerMixin): - def __init__(self, agents: Optional[Mapping[str, Agent]] = None): + def __init__(self, + agents: Optional[Mapping[str, Union[Agent, + AsyncAgent]]] = None): Agent.__init__(self, memory=None) UserDict.__init__(self, agents) self.name = None diff --git a/lagent/agents/aggregator/__pycache__/__init__.cpython-310.pyc b/lagent/agents/aggregator/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc4d82716cd03fcde4145e3683c52f8b3698931a Binary files /dev/null and b/lagent/agents/aggregator/__pycache__/__init__.cpython-310.pyc differ diff --git 
a/lagent/agents/aggregator/__pycache__/default_aggregator.cpython-310.pyc b/lagent/agents/aggregator/__pycache__/default_aggregator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f929b7139bd57eeb4575a77ddff494203cb25e35 Binary files /dev/null and b/lagent/agents/aggregator/__pycache__/default_aggregator.cpython-310.pyc differ diff --git a/lagent/agents/aggregator/__pycache__/tool_aggregator.cpython-310.pyc b/lagent/agents/aggregator/__pycache__/tool_aggregator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..214e104a2b3d4abe2ea0a92b0944ad4eedb1ae32 Binary files /dev/null and b/lagent/agents/aggregator/__pycache__/tool_aggregator.cpython-310.pyc differ diff --git a/lagent/agents/react.py b/lagent/agents/react.py index 4a942a068c44e8481ec8f382b63947407826305b..41d2414d0f1d15066aba5f56cae9afd9c9140c7c 100644 --- a/lagent/agents/react.py +++ b/lagent/agents/react.py @@ -12,6 +12,7 @@ from lagent.memory import Memory from lagent.prompts.parsers.json_parser import JSONParser from lagent.prompts.prompt_template import PromptTemplate from lagent.schema import AgentMessage +from lagent.utils import create_object select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括: {action_info} @@ -27,88 +28,96 @@ output_format_template = """如果使用工具请遵循以下格式回复: class ReAct(Agent): - def __init__( - self, - llm: Union[BaseLLM, Dict], - actions: Union[BaseAction, List[BaseAction]], - template: Union[PromptTemplate, str] = None, - memory: Dict = dict(type=Memory), - output_format: Dict = dict(type=JSONParser), - aggregator: Dict = dict(type=DefaultAggregator), - hooks: List = [dict(type=ActionPreprocessor)], - finish_condition: Callable[[AgentMessage], bool] = lambda m: 'conclusion' in m.content - or 'conclusion' in m.formatted, - max_turn: int = 5, - **kwargs - ): + def __init__(self, + llm: Union[BaseLLM, Dict], + actions: Union[BaseAction, List[BaseAction]], + template: Union[PromptTemplate, str] = None, + memory: Dict = dict(type=Memory), + output_format: Dict = dict(type=JSONParser), + aggregator: Dict = dict(type=DefaultAggregator), + hooks: List = [dict(type=ActionPreprocessor)], + finish_condition: Callable[[AgentMessage], bool] = lambda m: + 'conclusion' in m.content or 'conclusion' in m.formatted, + max_turn: int = 5, + **kwargs): self.max_turn = max_turn self.finish_condition = finish_condition - self.actions = ActionExecutor(actions=actions, hooks=hooks) - self.select_agent = Agent( + actions = dict( + type=ActionExecutor, + actions=actions, + hooks=hooks, + ) + self.actions: ActionExecutor = create_object(actions) + select_agent = dict( + type=Agent, llm=llm, template=template.format( - action_info=json.dumps(self.actions.description()), output_format=output_format.format_instruction() - ), + action_info=json.dumps(self.actions.description()), + output_format=output_format.format_instruction()), output_format=output_format, memory=memory, aggregator=aggregator, hooks=hooks, ) + self.select_agent = create_object(select_agent) super().__init__(**kwargs) - def forward(self, message: AgentMessage, session_id=0, **kwargs) -> AgentMessage: + def forward(self, message: AgentMessage, **kwargs) -> AgentMessage: for _ in range(self.max_turn): - message = self.select_agent(message, session_id=session_id, **kwargs) + message = self.select_agent(message) if self.finish_condition(message): return message - message = self.actions(message, session_id=session_id) + message = self.actions(message) return message class AsyncReAct(AsyncAgent): - def 
__init__( - self, - llm: Union[BaseLLM, Dict], - actions: Union[BaseAction, List[BaseAction]], - template: Union[PromptTemplate, str] = None, - memory: Dict = dict(type=Memory), - output_format: Dict = dict(type=JSONParser), - aggregator: Dict = dict(type=DefaultAggregator), - hooks: List = [dict(type=ActionPreprocessor)], - finish_condition: Callable[[AgentMessage], bool] = lambda m: 'conclusion' in m.content - or 'conclusion' in m.formatted, - max_turn: int = 5, - **kwargs - ): + def __init__(self, + llm: Union[BaseLLM, Dict], + actions: Union[BaseAction, List[BaseAction]], + template: Union[PromptTemplate, str] = None, + memory: Dict = dict(type=Memory), + output_format: Dict = dict(type=JSONParser), + aggregator: Dict = dict(type=DefaultAggregator), + hooks: List = [dict(type=ActionPreprocessor)], + finish_condition: Callable[[AgentMessage], bool] = lambda m: + 'conclusion' in m.content or 'conclusion' in m.formatted, + max_turn: int = 5, + **kwargs): self.max_turn = max_turn self.finish_condition = finish_condition - self.actions = AsyncActionExecutor(actions=actions, hooks=hooks) - self.select_agent = AsyncAgent( + actions = dict( + type=AsyncActionExecutor, + actions=actions, + hooks=hooks, + ) + self.actions: AsyncActionExecutor = create_object(actions) + select_agent = dict( + type=AsyncAgent, llm=llm, template=template.format( - action_info=json.dumps(self.actions.description()), output_format=output_format.format_instruction() - ), + action_info=json.dumps(self.actions.description()), + output_format=output_format.format_instruction()), output_format=output_format, memory=memory, aggregator=aggregator, hooks=hooks, ) + self.select_agent = create_object(select_agent) super().__init__(**kwargs) - async def forward(self, message: AgentMessage, session_id=0, **kwargs) -> AgentMessage: + async def forward(self, message: AgentMessage, **kwargs) -> AgentMessage: for _ in range(self.max_turn): - message = await self.select_agent(message, session_id=session_id, **kwargs) + message = await self.select_agent(message) if self.finish_condition(message): return message - message = await self.actions(message, session_id=session_id) + message = await self.actions(message) return message if __name__ == '__main__': - import asyncio - - from lagent.llms import GPTAPI, AsyncGPTAPI + from lagent.llms import GPTAPI class ActionCall(BaseModel): name: str = Field(description='调用的函数名称') @@ -116,49 +125,37 @@ if __name__ == '__main__': class ActionFormat(BaseModel): thought_process: str = Field( - description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。' - ) + description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。') action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。') class FinishFormat(BaseModel): thought_process: str = Field( - description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。' - ) + description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。') conclusion: str = Field(description='总结当前的搜索结果,回答问题。') prompt_template = PromptTemplate(select_action_template) - output_format = JSONParser(output_format_template, function_format=ActionFormat, finish_format=FinishFormat) + output_format = JSONParser( + output_format_template, + function_format=ActionFormat, + finish_format=FinishFormat) + + llm = dict( + type=GPTAPI, + model_type='gpt-4o-2024-05-13', + key=None, + max_new_tokens=4096, + proxies=dict(), + retry=1000) agent = ReAct( - llm=dict( - type=GPTAPI, - model_type='gpt-4o-2024-05-13', - max_new_tokens=4096, - proxies=dict(), - retry=1000, - ), + llm=llm, template=prompt_template, 
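# Note: `llm` is left as a plain dict config here; ReAct resolves such
# dict(type=Cls, ...) configs through lagent.utils.create_object, which
# instantiates Cls with the remaining keys, keeping setups declarative.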
output_format=output_format, - aggregator=dict(type='lagent.agents.aggregator.DefaultAggregator'), - actions=[dict(type='lagent.actions.PythonInterpreter')], + aggregator=dict(type='DefaultAggregator'), + actions=[dict(type='PythonInterpreter')], ) - response = agent(AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5')) + response = agent( + AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5')) print(response) response = agent(AgentMessage(sender='user', content=' 2 ** 5 呢')) print(response) - - async_agent = AsyncReAct( - llm=dict( - type=AsyncGPTAPI, - model_type='gpt-4o-2024-05-13', - max_new_tokens=4096, - proxies=dict(), - retry=1000, - ), - template=prompt_template, - output_format=output_format, - aggregator=dict(type='lagent.agents.aggregator.DefaultAggregator'), - actions=[dict(type='lagent.actions.AsyncPythonInterpreter')], - ) - response = asyncio.run(async_agent(AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5'))) - print(async_agent.state_dict()) diff --git a/lagent/agents/stream.py b/lagent/agents/stream.py index 9687aac1e17b8b89930e10ec6d59df66ce3bf04b..512250ff02c7dd3f09dd844e999e343b597feab8 100644 --- a/lagent/agents/stream.py +++ b/lagent/agents/stream.py @@ -15,30 +15,25 @@ from lagent.utils import create_object API_PREFIX = ( "This is the subfunction for tool '{tool_name}', you can use this tool. " - 'The description of this function is: \n{description}' -) + 'The description of this function is: \n{description}') -META_CN = '当开启工具以及代码时,根据需求选择合适的工具进行调用' +META_CN = ('当开启工具以及代码时,根据需求选择合适的工具进行调用') -INTERPRETER_CN = ( - '你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。' - '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。' - '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),' - '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),' - '文本处理和分析(比如文本解析和自然语言处理),' - '机器学习和数据科学(用于展示模型训练和数据可视化),' - '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。' -) +INTERPRETER_CN = ('你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。' + '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。' + '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),' + '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),' + '文本处理和分析(比如文本解析和自然语言处理),' + '机器学习和数据科学(用于展示模型训练和数据可视化),' + '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。') -PLUGIN_CN = ( - '你可以使用如下工具:' - '\n{prompt}\n' - '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! ' - '同时注意你可以使用的工具,不要随意捏造!' -) +PLUGIN_CN = ('你可以使用如下工具:' + '\n{prompt}\n' + '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! 
' + '同时注意你可以使用的工具,不要随意捏造!') -def get_plugin_prompt(actions, api_desc_template='{description}'): +def get_plugin_prompt(actions, api_desc_template=API_PREFIX): plugin_descriptions = [] for action in actions if isinstance(actions, list) else [actions]: action = create_object(action) @@ -46,9 +41,20 @@ def get_plugin_prompt(actions, api_desc_template='{description}'): if action.is_toolkit: for api in action_desc['api_list']: api['name'] = f"{action.name}.{api['name']}" - api['description'] = api_desc_template.format(tool_name=action.name, description=api['description']) + api['description'] = api_desc_template.format( + tool_name=action.name, description=api['description']) + api['parameters'] = [ + param for param in api['parameters'] + if param['name'] in api['required'] + ] plugin_descriptions.append(api) else: + action_desc['description'] = api_desc_template.format( + tool_name=action.name, description=action_desc['description']) + action_desc['parameters'] = [ + param for param in action_desc['parameters'] + if param['name'] in action_desc['required'] + ] plugin_descriptions.append(action_desc) return json.dumps(plugin_descriptions, ensure_ascii=False, indent=4) @@ -70,15 +76,17 @@ class AgentForInternLM(Agent): parsers=[ dict(type=PluginParser, template=PLUGIN_CN), dict(type=InterpreterParser, template=INTERPRETER_CN), - ], - ), + ]), aggregator: Dict = dict(type=InternLMToolAggregator), action_hooks: List = [dict(type=InternLMActionProcessor)], - finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, + finish_condition: Callable[ + [AgentMessage], + bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, max_turn: int = 4, **kwargs, ): - self.agent = self._INTERNAL_AGENT_CLS( + agent = dict( + type=self._INTERNAL_AGENT_CLS, llm=llm, template=template, output_format=output_format, @@ -86,18 +94,22 @@ class AgentForInternLM(Agent): aggregator=aggregator, hooks=kwargs.pop('hooks', None), ) - self.plugin_executor = plugins and ActionExecutor(plugins, hooks=action_hooks) - self.interpreter_executor = interpreter and ActionExecutor(interpreter, hooks=action_hooks) + self.agent = create_object(agent) + self.plugin_executor = plugins and ActionExecutor( + plugins, hooks=action_hooks) + self.interpreter_executor = interpreter and ActionExecutor( + interpreter, hooks=action_hooks) if not (self.plugin_executor or self.interpreter_executor): warnings.warn( 'Neither plugin nor interpreter executor is initialized. ' - 'An exception will be thrown when the agent call a tool.' 
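The parameter filtering that get_plugin_prompt now applies keeps only the arguments listed in a tool's required field, so optional parameters never reach the prompt. Standalone, with a fabricated api record for illustration:

api = {
    'name': 'ArxivSearch.get_arxiv_article_information',
    'description': 'Search arXiv',
    'parameters': [
        {'name': 'query', 'type': 'STRING'},
        {'name': 'max_results', 'type': 'NUMBER'},
    ],
    'required': ['query'],
}
# Drop everything not listed as required, exactly as in the hunk above.
api['parameters'] = [
    param for param in api['parameters'] if param['name'] in api['required']
]
print(api['parameters'])  # -> [{'name': 'query', 'type': 'STRING'}]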
- ) + 'An exception will be thrown when the agent call a tool.') self.finish_condition = finish_condition self.max_turn = max_turn super().__init__(**kwargs) def forward(self, message: AgentMessage, session_id=0, **kwargs): + if isinstance(message, str): + message = AgentMessage(sender='user', content=message) for _ in range(self.max_turn): message = self.agent(message, session_id=session_id, **kwargs) assert isinstance(message.formatted, dict) @@ -115,10 +127,15 @@ class AgentForInternLM(Agent): steps, tool_type = [], None for msg in self.agent.memory.get_memory(session_id): if msg.sender == self.agent.name: - steps.append(dict(role='thought', content=msg.formatted['thought'])) + steps.append( + dict(role='thought', content=msg.formatted['thought'])) if msg.formatted['tool_type']: tool_type = msg.formatted['tool_type'] - steps.append(dict(role='tool', content=msg.formatted['action'], name=tool_type)) + steps.append( + dict( + role='tool', + content=msg.formatted['action'], + name=tool_type)) elif msg.sender != 'user': feedback = dict(role='environment', content=msg.content) if tool_type: @@ -132,22 +149,23 @@ class MathCoder(AgentForInternLM): def __init__( self, llm: Union[BaseLLM, Dict], - interpreter: dict = dict(type=IPythonInteractive, timeout=20, max_out_len=8192), + interpreter: dict = dict( + type=IPythonInteractive, timeout=20, max_out_len=8192), template: Union[str, dict, List[dict]] = None, memory: Dict = dict(type=Memory), output_format: Dict = dict( type=InterpreterParser, - template=( - 'Integrate step-by-step reasoning and Python code to solve math problems ' - 'using the following guidelines:\n' - '- Analyze the question and write jupyter code to solve the problem;\n' - r"- Present the final result in LaTeX using a '\boxed{{}}' without any " - 'units. \n' - ), - ), + template= + ('Integrate step-by-step reasoning and Python code to solve math problems ' + 'using the following guidelines:\n' + '- Analyze the question and write jupyter code to solve the problem;\n' + r"- Present the final result in LaTeX using a '\boxed{{}}' without any " + 'units. 
\n')), aggregator: Dict = dict(type=InternLMToolAggregator), action_hooks: List = [dict(type=InternLMActionProcessor)], - finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, + finish_condition: Callable[ + [AgentMessage], + bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, max_turn: int = 6, **kwargs, ): @@ -162,8 +180,7 @@ class MathCoder(AgentForInternLM): action_hooks=action_hooks, finish_condition=finish_condition, max_turn=max_turn, - **kwargs, - ) + **kwargs) class AsyncAgentForInternLM(AsyncAgent): @@ -183,15 +200,17 @@ class AsyncAgentForInternLM(AsyncAgent): parsers=[ dict(type=PluginParser, template=PLUGIN_CN), dict(type=InterpreterParser, template=INTERPRETER_CN), - ], - ), + ]), aggregator: Dict = dict(type=InternLMToolAggregator), action_hooks: List = [dict(type=InternLMActionProcessor)], - finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, + finish_condition: Callable[ + [AgentMessage], + bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, max_turn: int = 4, **kwargs, ): - self.agent = self._INTERNAL_AGENT_CLS( + agent = dict( + type=self._INTERNAL_AGENT_CLS, llm=llm, template=template, output_format=output_format, @@ -199,20 +218,25 @@ class AsyncAgentForInternLM(AsyncAgent): aggregator=aggregator, hooks=kwargs.pop('hooks', None), ) - self.plugin_executor = plugins and AsyncActionExecutor(plugins, hooks=action_hooks) - self.interpreter_executor = interpreter and AsyncActionExecutor(interpreter, hooks=action_hooks) + self.agent = create_object(agent) + self.plugin_executor = plugins and AsyncActionExecutor( + plugins, hooks=action_hooks) + self.interpreter_executor = interpreter and AsyncActionExecutor( + interpreter, hooks=action_hooks) if not (self.plugin_executor or self.interpreter_executor): warnings.warn( 'Neither plugin nor interpreter executor is initialized. ' - 'An exception will be thrown when the agent call a tool.' 
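AgentForInternLM.forward and its async twin both run the same bounded tool loop: query the inner agent, stop as soon as finish_condition fires (status == ToolStatusCode.NO_TOOL), otherwise hand the message to an executor and feed the observation back. A simplified sketch; the single executors mapping keyed by tool_type is an illustration, where the real forward chooses between the plugin and interpreter executors:

def tool_loop(agent, executors, finish_condition, message, max_turn=4):
    for _ in range(max_turn):
        message = agent(message)
        if finish_condition(message):
            return message  # model answered without requesting a tool
        executor = executors[message.formatted['tool_type']]
        message = executor(message)  # observation goes back to the agent
    return message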
- ) + 'An exception will be thrown when the agent call a tool.') self.finish_condition = finish_condition self.max_turn = max_turn super().__init__(**kwargs) async def forward(self, message: AgentMessage, session_id=0, **kwargs): + if isinstance(message, str): + message = AgentMessage(sender='user', content=message) for _ in range(self.max_turn): - message = await self.agent(message, session_id=session_id, **kwargs) + message = await self.agent( + message, session_id=session_id, **kwargs) assert isinstance(message.formatted, dict) if self.finish_condition(message): return message @@ -228,10 +252,15 @@ class AsyncAgentForInternLM(AsyncAgent): steps, tool_type = [], None for msg in self.agent.memory.get_memory(session_id): if msg.sender == self.agent.name: - steps.append(dict(role='thought', content=msg.formatted['thought'])) + steps.append( + dict(role='thought', content=msg.formatted['thought'])) if msg.formatted['tool_type']: tool_type = msg.formatted['tool_type'] - steps.append(dict(role='tool', content=msg.formatted['action'], name=tool_type)) + steps.append( + dict( + role='tool', + content=msg.formatted['action'], + name=tool_type)) elif msg.sender != 'user': feedback = dict(role='environment', content=msg.content) if tool_type: @@ -250,17 +279,17 @@ class AsyncMathCoder(AsyncAgentForInternLM): memory: Dict = dict(type=Memory), output_format: Dict = dict( type=InterpreterParser, - template=( - 'Integrate step-by-step reasoning and Python code to solve math problems ' - 'using the following guidelines:\n' - '- Analyze the question and write jupyter code to solve the problem;\n' - r"- Present the final result in LaTeX using a '\boxed{{}}' without any " - 'units. \n' - ), - ), + template= + ('Integrate step-by-step reasoning and Python code to solve math problems ' + 'using the following guidelines:\n' + '- Analyze the question and write jupyter code to solve the problem;\n' + r"- Present the final result in LaTeX using a '\boxed{{}}' without any " + 'units. 
\n')), aggregator: Dict = dict(type=InternLMToolAggregator), action_hooks: List = [dict(type=InternLMActionProcessor)], - finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, + finish_condition: Callable[ + [AgentMessage], + bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, max_turn: int = 6, **kwargs, ): @@ -275,13 +304,13 @@ class AsyncMathCoder(AsyncAgentForInternLM): action_hooks=action_hooks, finish_condition=finish_condition, max_turn=max_turn, - **kwargs, - ) + **kwargs) async def forward(self, message: AgentMessage, session_id=0, **kwargs): try: return await super().forward(message, session_id, **kwargs) finally: - interpreter = next(iter(self.interpreter_executor.actions.values())) + interpreter = next( + iter(self.interpreter_executor.actions.values())) if interpreter.name == 'AsyncIPythonInterpreter': await interpreter.close_session(session_id) diff --git a/lagent/distributed/http_serve/api_server.py b/lagent/distributed/http_serve/api_server.py index e78dcacc15162fcdb77e4182f3ea3b9dd69577fe..0cb6907ab4ce1bc3b73dbc74898b29cbd3f8c6f8 100644 --- a/lagent/distributed/http_serve/api_server.py +++ b/lagent/distributed/http_serve/api_server.py @@ -3,7 +3,6 @@ import os import subprocess import sys import time -import threading import aiohttp import requests @@ -78,21 +77,14 @@ class HTTPAgentServer(HTTPAgentClient): stderr=subprocess.STDOUT, text=True) - self.service_started = False - - def log_output(stream): - if stream is not None: - for line in iter(stream.readline, ''): - print(line, end='') - if 'Uvicorn running on' in line: - self.service_started = True - - # Start log output thread - threading.Thread(target=log_output, args=(self.process.stdout,), daemon=True).start() - threading.Thread(target=log_output, args=(self.process.stderr,), daemon=True).start() - - # Waiting for the service to start - while not self.service_started: + while True: + output = self.process.stdout.readline() + if not output: # break out of the loop on EOF + break + sys.stdout.write(output) # echo the child's log to stdout + sys.stdout.flush() + if 'Uvicorn running on' in output: # adjust to match the actual startup log + break time.sleep(0.1) def shutdown(self): diff --git a/lagent/hooks/__pycache__/__init__.cpython-310.pyc b/lagent/hooks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cd72c4de2b7709a904a6f7d31ba86002e480626 Binary files /dev/null and b/lagent/hooks/__pycache__/__init__.cpython-310.pyc differ diff --git a/lagent/hooks/__pycache__/action_preprocessor.cpython-310.pyc b/lagent/hooks/__pycache__/action_preprocessor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0435f9ac14cfe432a3e43bf8b55fd635877c4cb1 Binary files /dev/null and b/lagent/hooks/__pycache__/action_preprocessor.cpython-310.pyc differ diff --git a/lagent/hooks/__pycache__/hook.cpython-310.pyc b/lagent/hooks/__pycache__/hook.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..648264eee7404045b06e7681f32cc978546269b3 Binary files /dev/null and b/lagent/hooks/__pycache__/hook.cpython-310.pyc differ diff --git a/lagent/hooks/__pycache__/logger.cpython-310.pyc b/lagent/hooks/__pycache__/logger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c3274dc7d5bbfe893cfe889f46c6124b32cfc46 Binary files /dev/null and b/lagent/hooks/__pycache__/logger.cpython-310.pyc differ diff --git a/lagent/hooks/logger.py b/lagent/hooks/logger.py index 
ccdb80124aa1f28180d461e6340c800e6c7b38fd..50224e432a6ca1177f2f39ff760fd5855fcf43d9 100644 --- a/lagent/hooks/logger.py +++ b/lagent/hooks/logger.py @@ -1,4 +1,5 @@ import random +from typing import Optional from termcolor import COLORS, colored @@ -7,10 +8,10 @@ from .hook import Hook class MessageLogger(Hook): - def __init__(self, name: str = 'lagent', add_file_handler: bool = False): + + def __init__(self, name: str = 'lagent'): self.logger = get_logger( - name, 'info', '%(asctime)s %(levelname)8s %(name)8s - %(message)s', add_file_handler=add_file_handler - ) + name, 'info', '%(asctime)s %(levelname)8s %(name)8s - %(message)s') self.sender2color = {} def before_agent(self, agent, messages, session_id): @@ -28,5 +29,9 @@ class MessageLogger(Hook): def _process_message(self, message, session_id): sender = message.sender - color = self.sender2color.setdefault(sender, random.choice(list(COLORS))) - self.logger.info(colored(f'session id: {session_id}, message sender: {sender}\n' f'{message.content}', color)) + color = self.sender2color.setdefault(sender, + random.choice(list(COLORS))) + self.logger.info( + colored( + f'session id: {session_id}, message sender: {sender}\n' + f'{message.content}', color)) diff --git a/lagent/llms/__init__.py b/lagent/llms/__init__.py index 95679b156755edb26d390e5373146d59edb1318a..fcbbd07d4622b1bf53a9b0daebb6a1c35a6a1711 100644 --- a/lagent/llms/__init__.py +++ b/lagent/llms/__init__.py @@ -1,15 +1,9 @@ -from .anthropic_llm import AsyncClaudeAPI, ClaudeAPI from .base_api import AsyncBaseAPILLM, BaseAPILLM from .base_llm import AsyncBaseLLM, BaseLLM from .huggingface import HFTransformer, HFTransformerCasualLM, HFTransformerChat -from .lmdeploy_wrapper import ( - AsyncLMDeployClient, - AsyncLMDeployPipeline, - AsyncLMDeployServer, - LMDeployClient, - LMDeployPipeline, - LMDeployServer, -) +from .lmdeploy_wrapper import (AsyncLMDeployClient, AsyncLMDeployPipeline, + AsyncLMDeployServer, LMDeployClient, + LMDeployPipeline, LMDeployServer) from .meta_template import INTERNLM2_META from .openai import GPTAPI, AsyncGPTAPI from .sensenova import SensenovaAPI @@ -35,6 +29,4 @@ __all__ = [ 'VllmModel', 'AsyncVllmModel', 'SensenovaAPI', - 'AsyncClaudeAPI', - 'ClaudeAPI', ] diff --git a/lagent/llms/__pycache__/__init__.cpython-310.pyc b/lagent/llms/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63b8e8d84764c08338d57ab9423777f651f43012 Binary files /dev/null and b/lagent/llms/__pycache__/__init__.cpython-310.pyc differ diff --git a/lagent/llms/__pycache__/base_api.cpython-310.pyc b/lagent/llms/__pycache__/base_api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..285237a83b816e84e7c6d32dd5260f69b04b6466 Binary files /dev/null and b/lagent/llms/__pycache__/base_api.cpython-310.pyc differ diff --git a/lagent/llms/__pycache__/base_llm.cpython-310.pyc b/lagent/llms/__pycache__/base_llm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db365ddf234cb13f58f47f3abaade6fd1224a575 Binary files /dev/null and b/lagent/llms/__pycache__/base_llm.cpython-310.pyc differ diff --git a/lagent/llms/__pycache__/huggingface.cpython-310.pyc b/lagent/llms/__pycache__/huggingface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9905b2a9c9e4d6fcfea6a83fbecdb7f3ca05ac94 Binary files /dev/null and b/lagent/llms/__pycache__/huggingface.cpython-310.pyc differ diff --git a/lagent/llms/__pycache__/lmdeploy_wrapper.cpython-310.pyc 
b/lagent/llms/__pycache__/lmdeploy_wrapper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..878ba0db770b0659f022239f7d97b03c7424e4a4 Binary files /dev/null and b/lagent/llms/__pycache__/lmdeploy_wrapper.cpython-310.pyc differ diff --git a/lagent/llms/__pycache__/meta_template.cpython-310.pyc b/lagent/llms/__pycache__/meta_template.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72f2d575403e79512135bef99846f5dfff52e2e9 Binary files /dev/null and b/lagent/llms/__pycache__/meta_template.cpython-310.pyc differ diff --git a/lagent/llms/__pycache__/openai.cpython-310.pyc b/lagent/llms/__pycache__/openai.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b68546f71f03890dafac7b049a7c51097bf0d48 Binary files /dev/null and b/lagent/llms/__pycache__/openai.cpython-310.pyc differ diff --git a/lagent/llms/__pycache__/sensenova.cpython-310.pyc b/lagent/llms/__pycache__/sensenova.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e239d68deadeea26340507fc078a0250e1bb6807 Binary files /dev/null and b/lagent/llms/__pycache__/sensenova.cpython-310.pyc differ diff --git a/lagent/llms/__pycache__/vllm_wrapper.cpython-310.pyc b/lagent/llms/__pycache__/vllm_wrapper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10a9c4608e3604ad44cb2e47ee736e5de342df5c Binary files /dev/null and b/lagent/llms/__pycache__/vllm_wrapper.cpython-310.pyc differ diff --git a/lagent/llms/anthropic_llm.py b/lagent/llms/anthropic_llm.py deleted file mode 100644 index 8aac28cf6621d325db80eb8f5e7c9c07ef9af777..0000000000000000000000000000000000000000 --- a/lagent/llms/anthropic_llm.py +++ /dev/null @@ -1,399 +0,0 @@ -import asyncio -import json -import os -from typing import Dict, List, Optional, Union - -import anthropic -import httpcore -import httpx -from anthropic import NOT_GIVEN -from requests.exceptions import ProxyError - -from .base_api import AsyncBaseAPILLM, BaseAPILLM - - -class ClaudeAPI(BaseAPILLM): - - is_api: bool = True - - def __init__( - self, - model_type: str = 'claude-3-5-sonnet-20241022', - retry: int = 5, - key: Union[str, List[str]] = 'ENV', - proxies: Optional[Dict] = None, - meta_template: Optional[Dict] = [ - dict(role='system', api_role='system'), - dict(role='user', api_role='user'), - dict(role='assistant', api_role='assistant'), - dict(role='environment', api_role='user'), - ], - temperature: float = NOT_GIVEN, - max_new_tokens: int = 512, - top_p: float = NOT_GIVEN, - top_k: int = NOT_GIVEN, - repetition_penalty: float = 0.0, - stop_words: Union[List[str], str] = None, - ): - - super().__init__( - meta_template=meta_template, - model_type=model_type, - retry=retry, - temperature=temperature, - max_new_tokens=max_new_tokens, - top_p=top_p, - top_k=top_k, - repetition_penalty=repetition_penalty, - stop_words=stop_words, - ) - - key = os.getenv('Claude_API_KEY') if key == 'ENV' else key - - if isinstance(key, str): - self.keys = [key] - else: - self.keys = list(set(key)) - self.clients = {key: anthropic.AsyncAnthropic(proxies=proxies, api_key=key) for key in self.keys} - - # record invalid keys and skip them when requesting API - # - keys have insufficient_quota - self.invalid_keys = set() - - self.key_ctr = 0 - self.model_type = model_type - self.proxies = proxies - - def chat( - self, - inputs: Union[List[dict], List[List[dict]]], - session_ids: Union[int, List[int]] = None, - **gen_params, - ) -> Union[str, List[str]]: - 
"""Generate responses given the contexts. - - Args: - inputs (Union[List[dict], List[List[dict]]]): a list of messages - or list of lists of messages - gen_params: additional generation configuration - - Returns: - Union[str, List[str]]: generated string(s) - """ - assert isinstance(inputs, list) - gen_params = {**self.gen_params, **gen_params} - import nest_asyncio - - nest_asyncio.apply() - - async def run_async_tasks(): - tasks = [ - self._chat(self.template_parser(messages), **gen_params) - for messages in ([inputs] if isinstance(inputs[0], dict) else inputs) - ] - return await asyncio.gather(*tasks) - - try: - loop = asyncio.get_running_loop() - # If the event loop is already running, schedule the task - future = asyncio.ensure_future(run_async_tasks()) - ret = loop.run_until_complete(future) - except RuntimeError: - # If no running event loop, start a new one - ret = asyncio.run(run_async_tasks()) - return ret[0] if isinstance(inputs[0], dict) else ret - - def generate_request_data(self, model_type, messages, gen_params): - """ - Generates the request data for different model types. - - Args: - model_type (str): The type of the model (e.g., 'gpt', 'internlm', 'qwen'). - messages (list): The list of messages to be sent to the model. - gen_params (dict): The generation parameters. - json_mode (bool): Flag to determine if the response format should be JSON. - - Returns: - tuple: A tuple containing the header and the request data. - """ - # Copy generation parameters to avoid modifying the original dictionary - gen_params = gen_params.copy() - - # Hold out 100 tokens due to potential errors in token calculation - max_tokens = min(gen_params.pop('max_new_tokens'), 4096) - if max_tokens <= 0: - return '', '' - gen_params.pop('repetition_penalty') - if 'stop_words' in gen_params: - gen_params['stop_sequences'] = gen_params.pop('stop_words') - # Common parameters processing - gen_params['max_tokens'] = max_tokens - gen_params.pop('skip_special_tokens', None) - gen_params.pop('session_id', None) - - system = None - if messages[0]['role'] == 'system': - system = messages.pop(0) - system = system['content'] - for message in messages: - message.pop('name', None) - data = {'model': model_type, 'messages': messages, **gen_params} - if system: - data['system'] = system - return data - - async def _chat(self, messages: List[dict], **gen_params) -> str: - """Generate completion from a list of templates. - - Args: - messages (List[dict]): a list of prompt dictionaries - gen_params: additional generation configuration - - Returns: - str: The generated string. 
- """ - assert isinstance(messages, list) - - data = self.generate_request_data(model_type=self.model_type, messages=messages, gen_params=gen_params) - max_num_retries = 0 - - while max_num_retries < self.retry: - if len(self.invalid_keys) == len(self.keys): - raise RuntimeError('All keys have insufficient quota.') - # find the next valid key - while True: - self.key_ctr += 1 - if self.key_ctr == len(self.keys): - self.key_ctr = 0 - - if self.keys[self.key_ctr] not in self.invalid_keys: - break - - key = self.keys[self.key_ctr] - client = self.clients[key] - - try: - response = await client.messages.create(**data) - response = json.loads(response.json()) - return response['content'][0]['text'].strip() - except (anthropic.RateLimitError, anthropic.APIConnectionError) as e: - print(f'API请求错误: {e}') - await asyncio.sleep(5) - - except (httpcore.ProxyError, ProxyError) as e: - - print(f'代理服务器错误: {e}') - await asyncio.sleep(5) - except httpx.TimeoutException as e: - print(f'请求超时: {e}') - await asyncio.sleep(5) - - except KeyboardInterrupt: - raise - - except Exception as error: - if error.body['error']['message'] == 'invalid x-api-key': - self.invalid_keys.add(key) - self.logger.warn(f'invalid key: {key}') - elif error.body['error']['type'] == 'overloaded_error': - await asyncio.sleep(5) - elif error.body['error']['message'] == 'Internal server error': - await asyncio.sleep(5) - elif error.body['error']['message'] == ( - 'Your credit balance is too low to access the Anthropic API. Please go to Plans & Billing to ' - 'upgrade or purchase credits.' - ): - self.invalid_keys.add(key) - print(f'API has no quota: {key}, Valid keys: {len(self.keys) - len(self.invalid_keys)}') - max_num_retries += 1 - - raise RuntimeError( - 'Calling Claude failed after retrying for ' f'{max_num_retries} times. Check the logs for ' 'details.' - ) - - -class AsyncClaudeAPI(AsyncBaseAPILLM): - - is_api: bool = True - - def __init__( - self, - model_type: str = 'claude-3-5-sonnet-20241022', - retry: int = 5, - key: Union[str, List[str]] = 'ENV', - proxies: Optional[Dict] = None, - meta_template: Optional[Dict] = [ - dict(role='system', api_role='system'), - dict(role='user', api_role='user'), - dict(role='assistant', api_role='assistant'), - dict(role='environment', api_role='user'), - ], - temperature: float = NOT_GIVEN, - max_new_tokens: int = 512, - top_p: float = NOT_GIVEN, - top_k: int = NOT_GIVEN, - repetition_penalty: float = 0.0, - stop_words: Union[List[str], str] = None, - ): - - super().__init__( - model_type=model_type, - retry=retry, - meta_template=meta_template, - temperature=temperature, - max_new_tokens=max_new_tokens, - top_p=top_p, - top_k=top_k, - repetition_penalty=repetition_penalty, - stop_words=stop_words, - ) - - key = os.getenv('Claude_API_KEY') if key == 'ENV' else key - - if isinstance(key, str): - self.keys = [key] - else: - self.keys = list(set(key)) - self.clients = {key: anthropic.AsyncAnthropic(proxies=proxies, api_key=key) for key in self.keys} - - # record invalid keys and skip them when requesting API - # - keys have insufficient_quota - self.invalid_keys = set() - - self.key_ctr = 0 - self.model_type = model_type - self.proxies = proxies - - async def chat( - self, - inputs: Union[List[dict], List[List[dict]]], - session_ids: Union[int, List[int]] = None, - **gen_params, - ) -> Union[str, List[str]]: - """Generate responses given the contexts. 
- - Args: - inputs (Union[List[dict], List[List[dict]]]): a list of messages - or list of lists of messages - gen_params: additional generation configuration - - Returns: - Union[str, List[str]]: generated string(s) - """ - assert isinstance(inputs, list) - gen_params = {**self.gen_params, **gen_params} - tasks = [ - self._chat(messages, **gen_params) for messages in ([inputs] if isinstance(inputs[0], dict) else inputs) - ] - ret = await asyncio.gather(*tasks) - return ret[0] if isinstance(inputs[0], dict) else ret - - def generate_request_data(self, model_type, messages, gen_params): - """ - Generates the request data for different model types. - - Args: - model_type (str): The type of the model (e.g., 'gpt', 'internlm', 'qwen'). - messages (list): The list of messages to be sent to the model. - gen_params (dict): The generation parameters. - json_mode (bool): Flag to determine if the response format should be JSON. - - Returns: - tuple: A tuple containing the header and the request data. - """ - # Copy generation parameters to avoid modifying the original dictionary - gen_params = gen_params.copy() - - # Hold out 100 tokens due to potential errors in token calculation - max_tokens = min(gen_params.pop('max_new_tokens'), 4096) - if max_tokens <= 0: - return '', '' - gen_params.pop('repetition_penalty') - if 'stop_words' in gen_params: - gen_params['stop_sequences'] = gen_params.pop('stop_words') - # Common parameters processing - gen_params['max_tokens'] = max_tokens - gen_params.pop('skip_special_tokens', None) - gen_params.pop('session_id', None) - - system = None - if messages[0]['role'] == 'system': - system = messages.pop(0) - system = system['content'] - for message in messages: - message.pop('name', None) - data = {'model': model_type, 'messages': messages, **gen_params} - if system: - data['system'] = system - return data - - async def _chat(self, messages: List[dict], **gen_params) -> str: - """Generate completion from a list of templates. - - Args: - messages (List[dict]): a list of prompt dictionaries - gen_params: additional generation configuration - - Returns: - str: The generated string. 
- """ - assert isinstance(messages, list) - messages = self.template_parser(messages) - data = self.generate_request_data(model_type=self.model_type, messages=messages, gen_params=gen_params) - max_num_retries = 0 - - while max_num_retries < self.retry: - if len(self.invalid_keys) == len(self.keys): - raise RuntimeError('All keys have insufficient quota.') - # find the next valid key - while True: - self.key_ctr += 1 - if self.key_ctr == len(self.keys): - self.key_ctr = 0 - - if self.keys[self.key_ctr] not in self.invalid_keys: - break - - key = self.keys[self.key_ctr] - client = self.clients[key] - - try: - response = await client.messages.create(**data) - response = json.loads(response.json()) - return response['content'][0]['text'].strip() - except (anthropic.RateLimitError, anthropic.APIConnectionError) as e: - print(f'API请求错误: {e}') - await asyncio.sleep(5) - - except (httpcore.ProxyError, ProxyError) as e: - - print(f'代理服务器错误: {e}') - await asyncio.sleep(5) - except httpx.TimeoutException as e: - print(f'请求超时: {e}') - await asyncio.sleep(5) - - except KeyboardInterrupt: - raise - - except Exception as error: - if error.body['error']['message'] == 'invalid x-api-key': - self.invalid_keys.add(key) - self.logger.warn(f'invalid key: {key}') - elif error.body['error']['type'] == 'overloaded_error': - await asyncio.sleep(5) - elif error.body['error']['message'] == 'Internal server error': - await asyncio.sleep(5) - elif error.body['error']['message'] == ( - 'Your credit balance is too low to access the Anthropic API. Please go to Plans & Billing to' - ' upgrade or purchase credits.' - ): - self.invalid_keys.add(key) - print(f'API has no quota: {key}, Valid keys: {len(self.keys) - len(self.invalid_keys)}') - else: - raise error - max_num_retries += 1 - - raise RuntimeError( - 'Calling Claude failed after retrying for ' f'{max_num_retries} times. Check the logs for ' 'details.' 
- ) diff --git a/lagent/llms/lmdeploy_wrapper.py b/lagent/llms/lmdeploy_wrapper.py index b5813e79ae7c51f605da9b89927623059f3c16dd..283d50e8de090878d9e15c1d78d60ee8818fdc8a 100644 --- a/lagent/llms/lmdeploy_wrapper.py +++ b/lagent/llms/lmdeploy_wrapper.py @@ -556,19 +556,7 @@ class AsyncLMDeployPipeline(AsyncLLMMixin, LMDeployPipeline): assert len(inputs) == len(session_ids) prompt = inputs - do_sample = kwargs.pop('do_sample', None) gen_params = self.update_gen_params(**kwargs) - if do_sample is None: - do_sample = self.do_sample - if do_sample is not None and self.version < (0, 6, 0): - raise RuntimeError( - '`do_sample` parameter is not supported by lmdeploy until ' - f'v0.6.0, but currently using lmdeloy {self.str_version}') - if self.version >= (0, 6, 0): - if do_sample is None: - do_sample = gen_params['top_k'] > 1 or gen_params[ - 'temperature'] > 0 - gen_params.update(do_sample=do_sample) gen_config = GenerationConfig( skip_special_tokens=skip_special_tokens, **gen_params) diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py index 7418a65c32ed5d6d1f312b36bb1caec825ec8b69..ffbd1b3de10bb6799c673784367acb476fe495cf 100644 --- a/lagent/llms/openai.py +++ b/lagent/llms/openai.py @@ -47,27 +47,30 @@ class GPTAPI(BaseAPILLM): is_api: bool = True - def __init__( - self, - model_type: str = 'gpt-3.5-turbo', - retry: int = 2, - json_mode: bool = False, - key: Union[str, List[str]] = 'ENV', - org: Optional[Union[str, List[str]]] = None, - meta_template: Optional[Dict] = [ - dict(role='system', api_role='system'), - dict(role='user', api_role='user'), - dict(role='assistant', api_role='assistant'), - dict(role='environment', api_role='system'), - ], - api_base: str = OPENAI_API_BASE, - proxies: Optional[Dict] = None, - **gen_params, - ): + def __init__(self, + model_type: str = 'gpt-3.5-turbo', + retry: int = 2, + json_mode: bool = False, + key: Union[str, List[str]] = 'ENV', + org: Optional[Union[str, List[str]]] = None, + meta_template: Optional[Dict] = [ + dict(role='system', api_role='system'), + dict(role='user', api_role='user'), + dict(role='assistant', api_role='assistant'), + dict(role='environment', api_role='system') + ], + api_base: str = OPENAI_API_BASE, + proxies: Optional[Dict] = None, + **gen_params): if 'top_k' in gen_params: - warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning) + warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', + DeprecationWarning) gen_params.pop('top_k') - super().__init__(model_type=model_type, meta_template=meta_template, retry=retry, **gen_params) + super().__init__( + model_type=model_type, + meta_template=meta_template, + retry=retry, + **gen_params) self.gen_params.pop('top_k') self.logger = getLogger(__name__) @@ -112,8 +115,11 @@ class GPTAPI(BaseAPILLM): gen_params = {**self.gen_params, **gen_params} with ThreadPoolExecutor(max_workers=20) as executor: tasks = [ - executor.submit(self._chat, messages, **gen_params) - for messages in ([inputs] if isinstance(inputs[0], dict) else inputs) + executor.submit(self._chat, + self.template_parser._prompt2api(messages), + **gen_params) + for messages in ( + [inputs] if isinstance(inputs[0], dict) else inputs) ] ret = [task.result() for task in tasks] return ret[0] if isinstance(inputs[0], dict) else ret @@ -144,9 +150,12 @@ class GPTAPI(BaseAPILLM): if stop_words is None: stop_words = [] # mapping to role that openai supports - messages = self.template_parser(inputs) + messages = self.template_parser._prompt2api(inputs) for text in 
self._stream_chat(messages, **gen_params): - resp += text + if self.model_type.lower().startswith('qwen'): + resp = text + else: + resp += text if not resp: continue # remove stop_words @@ -171,10 +180,12 @@ class GPTAPI(BaseAPILLM): str: The generated string. """ assert isinstance(messages, list) - messages = self.template_parser(messages) + header, data = self.generate_request_data( - model_type=self.model_type, messages=messages, gen_params=gen_params, json_mode=self.json_mode - ) + model_type=self.model_type, + messages=messages, + gen_params=gen_params, + json_mode=self.json_mode) max_num_retries, errmsg = 0, '' while max_num_retries < self.retry: @@ -203,7 +214,11 @@ class GPTAPI(BaseAPILLM): response = dict() try: - raw_response = requests.post(self.url, headers=header, data=json.dumps(data), proxies=self.proxies) + raw_response = requests.post( + self.url, + headers=header, + data=json.dumps(data), + proxies=self.proxies) response = raw_response.json() return response['choices'][0]['message']['content'].strip() except requests.ConnectionError: @@ -224,18 +239,17 @@ class GPTAPI(BaseAPILLM): self.logger.warn(f'insufficient_quota key: {key}') continue - errmsg = 'Find error message in response: ' + str(response['error']) + errmsg = 'Find error message in response: ' + str( + response['error']) self.logger.error(errmsg) except Exception as error: errmsg = str(error) + '\n' + str(traceback.format_exc()) self.logger.error(errmsg) max_num_retries += 1 - raise RuntimeError( - 'Calling OpenAI failed after retrying for ' - f'{max_num_retries} times. Check the logs for ' - f'details. errmsg: {errmsg}' - ) + raise RuntimeError('Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + f'details. errmsg: {errmsg}') def _stream_chat(self, messages: List[dict], **gen_params) -> str: """Generate completion from a list of templates. 
@@ -249,7 +263,8 @@ class GPTAPI(BaseAPILLM): """ def streaming(raw_response): - for chunk in raw_response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b'\n'): + for chunk in raw_response.iter_lines( + chunk_size=8192, decode_unicode=False, delimiter=b'\n'): if chunk: decoded = chunk.decode('utf-8') if decoded.startswith('data: [DONE]'): @@ -267,11 +282,16 @@ # Context exceeds maximum length yield '' return - - choice = response['choices'][0] - if choice['finish_reason'] == 'stop': - return - yield choice['delta'].get('content', '') + if self.model_type.lower().startswith('qwen'): + choice = response['output']['choices'][0] + yield choice['message']['content'] + if choice['finish_reason'] == 'stop': + return + else: + choice = response['choices'][0] + if choice['finish_reason'] == 'stop': + return + yield choice['delta'].get('content', '') except Exception as exc: msg = f'response {decoded} lead to exception of {str(exc)}' self.logger.error(msg) @@ -280,8 +300,10 @@ assert isinstance(messages, list) header, data = self.generate_request_data( - model_type=self.model_type, messages=messages, gen_params=gen_params, json_mode=self.json_mode - ) + model_type=self.model_type, + messages=messages, + gen_params=gen_params, + json_mode=self.json_mode) max_num_retries, errmsg = 0, '' while max_num_retries < self.retry: @@ -308,7 +330,12 @@ response = dict() try: - raw_response = requests.post(self.url, headers=header, data=json.dumps(data), proxies=self.proxies, stream=True) + raw_response = requests.post( + self.url, + headers=header, + data=json.dumps(data), + proxies=self.proxies, + stream=True) return streaming(raw_response) except requests.ConnectionError: errmsg = 'Got connection error ' + str(traceback.format_exc()) @@ -328,20 +354,23 @@ self.logger.warn(f'insufficient_quota key: {key}') continue - errmsg = 'Find error message in response: ' + str(response['error']) + errmsg = 'Find error message in response: ' + str( + response['error']) self.logger.error(errmsg) except Exception as error: errmsg = str(error) + '\n' + str(traceback.format_exc()) self.logger.error(errmsg) max_num_retries += 1 - raise RuntimeError( - 'Calling OpenAI failed after retrying for ' - f'{max_num_retries} times. Check the logs for ' - f'details. errmsg: {errmsg}' - ) + raise RuntimeError('Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + f'details. errmsg: {errmsg}') - def generate_request_data(self, model_type, messages, gen_params, json_mode=False): + def generate_request_data(self, + model_type, + messages, + gen_params, + json_mode=False): """ Generates the request data for different model types. 
@@ -372,25 +401,56 @@ class GPTAPI(BaseAPILLM): if 'stop_words' in gen_params: gen_params['stop'] = gen_params.pop('stop_words') if 'repetition_penalty' in gen_params: - gen_params['frequency_penalty'] = gen_params.pop('repetition_penalty') + gen_params['frequency_penalty'] = gen_params.pop( + 'repetition_penalty') # Model-specific processing data = {} - if model_type.lower().startswith('gpt') or model_type.lower().startswith('qwen'): + if model_type.lower().startswith('gpt'): if 'top_k' in gen_params: - warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning) + warnings.warn( + '`top_k` parameter is deprecated in OpenAI APIs.', + DeprecationWarning) gen_params.pop('top_k') gen_params.pop('skip_special_tokens', None) gen_params.pop('session_id', None) - data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params} + data = { + 'model': model_type, + 'messages': messages, + 'n': 1, + **gen_params + } if json_mode: data['response_format'] = {'type': 'json_object'} elif model_type.lower().startswith('internlm'): - data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params} + data = { + 'model': model_type, + 'messages': messages, + 'n': 1, + **gen_params + } if json_mode: data['response_format'] = {'type': 'json_object'} + elif model_type.lower().startswith('qwen'): + header['X-DashScope-SSE'] = 'enable' + gen_params.pop('skip_special_tokens', None) + gen_params.pop('session_id', None) + if 'frequency_penalty' in gen_params: + gen_params['repetition_penalty'] = gen_params.pop( + 'frequency_penalty') + gen_params['result_format'] = 'message' + data = { + 'model': model_type, + 'input': { + 'messages': messages + }, + 'parameters': { + **gen_params + } + } else: - raise NotImplementedError(f'Model type {model_type} is not supported') + raise NotImplementedError( + f'Model type {model_type} is not supported') return header, data @@ -404,7 +464,6 @@ class GPTAPI(BaseAPILLM): list: token ids """ import tiktoken - self.tiktoken = tiktoken - enc = self.tiktoken.encoding_for_model(self.model_type) + enc = tiktoken.encoding_for_model(self.model_type) return enc.encode(prompt) @@ -436,27 +495,29 @@ class AsyncGPTAPI(AsyncBaseAPILLM): is_api: bool = True - def __init__( - self, - model_type: str = 'gpt-3.5-turbo', - retry: int = 2, - json_mode: bool = False, - key: Union[str, List[str]] = 'ENV', - org: Optional[Union[str, List[str]]] = None, - meta_template: Optional[Dict] = [ - dict(role='system', api_role='system'), - dict(role='user', api_role='user'), - dict(role='assistant', api_role='assistant'), - dict(role='environment', api_role='system'), - ], - api_base: str = OPENAI_API_BASE, - proxies: Optional[Dict] = None, - **gen_params, - ): + def __init__(self, + model_type: str = 'gpt-3.5-turbo', + retry: int = 2, + json_mode: bool = False, + key: Union[str, List[str]] = 'ENV', + org: Optional[Union[str, List[str]]] = None, + meta_template: Optional[Dict] = [ + dict(role='system', api_role='system'), + dict(role='user', api_role='user'), + dict(role='assistant', api_role='assistant') + ], + api_base: str = OPENAI_API_BASE, + proxies: Optional[Dict] = None, + **gen_params): if 'top_k' in gen_params: - warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning) + warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', + DeprecationWarning) gen_params.pop('top_k') - super().__init__(model_type=model_type, meta_template=meta_template, retry=retry, **gen_params) + super().__init__( + model_type=model_type, + meta_template=meta_template, + retry=retry, + **gen_params)
self.gen_params.pop('top_k') self.logger = getLogger(__name__) @@ -501,7 +562,8 @@ class AsyncGPTAPI(AsyncBaseAPILLM): raise NotImplementedError('unsupported parameter: max_tokens') gen_params = {**self.gen_params, **gen_params} tasks = [ - self._chat(messages, **gen_params) for messages in ([inputs] if isinstance(inputs[0], dict) else inputs) + self._chat(messages, **gen_params) for messages in ( + [inputs] if isinstance(inputs[0], dict) else inputs) ] ret = await asyncio.gather(*tasks) return ret[0] if isinstance(inputs[0], dict) else ret @@ -532,9 +594,12 @@ class AsyncGPTAPI(AsyncBaseAPILLM): if stop_words is None: stop_words = [] # mapping to role that openai supports - messages = self.template_parser(inputs) + messages = self.template_parser._prompt2api(inputs) async for text in self._stream_chat(messages, **gen_params): - resp += text + if self.model_type.lower().startswith('qwen'): + resp = text + else: + resp += text if not resp: continue # remove stop_words @@ -559,10 +624,12 @@ class AsyncGPTAPI(AsyncBaseAPILLM): str: The generated string. """ assert isinstance(messages, list) - messages = self.template_parser(messages) + header, data = self.generate_request_data( - model_type=self.model_type, messages=messages, gen_params=gen_params, json_mode=self.json_mode - ) + model_type=self.model_type, + messages=messages, + gen_params=gen_params, + json_mode=self.json_mode) max_num_retries, errmsg = 0, '' while max_num_retries < self.retry: @@ -591,10 +658,14 @@ class AsyncGPTAPI(AsyncBaseAPILLM): try: async with aiohttp.ClientSession() as session: async with session.post( - self.url, headers=header, json=data, proxy=self.proxies.get('https', self.proxies.get('http')) - ) as resp: + self.url, + headers=header, + json=data, + proxy=self.proxies.get( + 'https', self.proxies.get('http'))) as resp: response = await resp.json() - return response['choices'][0]['message']['content'].strip() + return response['choices'][0]['message'][ + 'content'].strip() except aiohttp.ClientConnectionError: errmsg = 'Got connection error ' + str(traceback.format_exc()) self.logger.error(errmsg) @@ -604,7 +675,8 @@ class AsyncGPTAPI(AsyncBaseAPILLM): self.logger.error(errmsg) continue except json.JSONDecodeError: - errmsg = 'JsonDecode error, got ' + (await resp.text(errors='replace')) + errmsg = 'JsonDecode error, got ' + (await resp.text( + errors='replace')) self.logger.error(errmsg) continue except KeyError: @@ -617,20 +689,20 @@ class AsyncGPTAPI(AsyncBaseAPILLM): self.logger.warn(f'insufficient_quota key: {key}') continue - errmsg = 'Find error message in response: ' + str(response['error']) + errmsg = 'Find error message in response: ' + str( + response['error']) self.logger.error(errmsg) except Exception as error: errmsg = str(error) + '\n' + str(traceback.format_exc()) self.logger.error(errmsg) max_num_retries += 1 - raise RuntimeError( - 'Calling OpenAI failed after retrying for ' - f'{max_num_retries} times. Check the logs for ' - f'details. errmsg: {errmsg}' - ) + raise RuntimeError('Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + f'details. errmsg: {errmsg}') - async def _stream_chat(self, messages: List[dict], **gen_params) -> AsyncGenerator[str, None]: + async def _stream_chat(self, messages: List[dict], + **gen_params) -> AsyncGenerator[str, None]: """Generate completion from a list of templates. 
Args: @@ -660,11 +732,16 @@ class AsyncGPTAPI(AsyncBaseAPILLM): # Context exceeds maximum length yield '' return - - choice = response['choices'][0] - if choice['finish_reason'] == 'stop': - return - yield choice['delta'].get('content', '') + if self.model_type.lower().startswith('qwen'): + choice = response['output']['choices'][0] + yield choice['message']['content'] + if choice['finish_reason'] == 'stop': + return + else: + choice = response['choices'][0] + if choice['finish_reason'] == 'stop': + return + yield choice['delta'].get('content', '') except Exception as exc: msg = f'response {decoded} lead to exception of {str(exc)}' self.logger.error(msg) @@ -673,8 +750,10 @@ class AsyncGPTAPI(AsyncBaseAPILLM): assert isinstance(messages, list) header, data = self.generate_request_data( - model_type=self.model_type, messages=messages, gen_params=gen_params, json_mode=self.json_mode - ) + model_type=self.model_type, + messages=messages, + gen_params=gen_params, + json_mode=self.json_mode) max_num_retries, errmsg = 0, '' while max_num_retries < self.retry: @@ -703,8 +782,12 @@ class AsyncGPTAPI(AsyncBaseAPILLM): try: async with aiohttp.ClientSession() as session: async with session.post( - self.url, headers=header, json=data, proxy=self.proxies.get('https', self.proxies.get('http')) - ) as raw_response: + self.url, + headers=header, + json=data, + proxy=self.proxies.get( + 'https', + self.proxies.get('http'))) as raw_response: async for msg in streaming(raw_response): yield msg return @@ -726,20 +809,23 @@ class AsyncGPTAPI(AsyncBaseAPILLM): self.logger.warn(f'insufficient_quota key: {key}') continue - errmsg = 'Find error message in response: ' + str(response['error']) + errmsg = 'Find error message in response: ' + str( + response['error']) self.logger.error(errmsg) except Exception as error: errmsg = str(error) + '\n' + str(traceback.format_exc()) self.logger.error(errmsg) max_num_retries += 1 - raise RuntimeError( - 'Calling OpenAI failed after retrying for ' - f'{max_num_retries} times. Check the logs for ' - f'details. errmsg: {errmsg}' - ) + raise RuntimeError('Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + f'details. errmsg: {errmsg}') - def generate_request_data(self, model_type, messages, gen_params, json_mode=False): + def generate_request_data(self, + model_type, + messages, + gen_params, + json_mode=False): """ Generates the request data for different model types. 
@@ -770,28 +856,56 @@ class AsyncGPTAPI(AsyncBaseAPILLM): if 'stop_words' in gen_params: gen_params['stop'] = gen_params.pop('stop_words') if 'repetition_penalty' in gen_params: - gen_params['frequency_penalty'] = gen_params.pop('repetition_penalty') + gen_params['frequency_penalty'] = gen_params.pop( + 'repetition_penalty') # Model-specific processing data = {} - if model_type.lower().startswith('gpt') or model_type.lower().startswith('qwen'): + if model_type.lower().startswith('gpt'): if 'top_k' in gen_params: - warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning) + warnings.warn( + '`top_k` parameter is deprecated in OpenAI APIs.', + DeprecationWarning) gen_params.pop('top_k') gen_params.pop('skip_special_tokens', None) gen_params.pop('session_id', None) - - data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params} + data = { + 'model': model_type, + 'messages': messages, + 'n': 1, + **gen_params + } if json_mode: data['response_format'] = {'type': 'json_object'} elif model_type.lower().startswith('internlm'): - data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params} + data = { + 'model': model_type, + 'messages': messages, + 'n': 1, + **gen_params + } if json_mode: data['response_format'] = {'type': 'json_object'} - elif model_type.lower().startswith('o1'): - data = {'model': model_type, 'messages': messages, 'n': 1} + elif model_type.lower().startswith('qwen'): + header['X-DashScope-SSE'] = 'enable' + gen_params.pop('skip_special_tokens', None) + gen_params.pop('session_id', None) + if 'frequency_penalty' in gen_params: + gen_params['repetition_penalty'] = gen_params.pop( + 'frequency_penalty') + gen_params['result_format'] = 'message' + data = { + 'model': model_type, + 'input': { + 'messages': messages + }, + 'parameters': { + **gen_params + } + } else: - raise NotImplementedError(f'Model type {model_type} is not supported') + raise NotImplementedError( + f'Model type {model_type} is not supported') return header, data @@ -805,7 +919,6 @@ class AsyncGPTAPI(AsyncBaseAPILLM): list: token ids """ import tiktoken - self.tiktoken = tiktoken - enc = self.tiktoken.encoding_for_model(self.model_type) + enc = tiktoken.encoding_for_model(self.model_type) return enc.encode(prompt) diff --git a/lagent/memory/__pycache__/__init__.cpython-310.pyc b/lagent/memory/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed060de87d4f933b6ba8404afbef0133ff625bf2 Binary files /dev/null and b/lagent/memory/__pycache__/__init__.cpython-310.pyc differ diff --git a/lagent/memory/__pycache__/base_memory.cpython-310.pyc b/lagent/memory/__pycache__/base_memory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..582735df64b09613ba48a6147944eb6c7e9ec2a3 Binary files /dev/null and b/lagent/memory/__pycache__/base_memory.cpython-310.pyc differ diff --git a/lagent/memory/__pycache__/manager.cpython-310.pyc b/lagent/memory/__pycache__/manager.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52ca7ea7ece7ee4d4aac95231199d5fcfd65b29a Binary files /dev/null and b/lagent/memory/__pycache__/manager.cpython-310.pyc differ diff --git a/lagent/prompts/__pycache__/__init__.cpython-310.pyc b/lagent/prompts/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..becc3a4365248b489109a7c37d2db7d1898ea7c6 Binary files /dev/null and b/lagent/prompts/__pycache__/__init__.cpython-310.pyc differ diff --git 
a/lagent/prompts/__pycache__/prompt_template.cpython-310.pyc b/lagent/prompts/__pycache__/prompt_template.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dec4e790bbd75a03793938bd6d2a85021dc0cfd Binary files /dev/null and b/lagent/prompts/__pycache__/prompt_template.cpython-310.pyc differ diff --git a/lagent/prompts/parsers/__pycache__/__init__.cpython-310.pyc b/lagent/prompts/parsers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50724aa226f6e8c14720e92edfce4616867b133c Binary files /dev/null and b/lagent/prompts/parsers/__pycache__/__init__.cpython-310.pyc differ diff --git a/lagent/prompts/parsers/__pycache__/custom_parser.cpython-310.pyc b/lagent/prompts/parsers/__pycache__/custom_parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75e77ce468006e2388960f33e480e349bb94f127 Binary files /dev/null and b/lagent/prompts/parsers/__pycache__/custom_parser.cpython-310.pyc differ diff --git a/lagent/prompts/parsers/__pycache__/json_parser.cpython-310.pyc b/lagent/prompts/parsers/__pycache__/json_parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07043fe6edc32b3c28b389076a3e5ac033acecf0 Binary files /dev/null and b/lagent/prompts/parsers/__pycache__/json_parser.cpython-310.pyc differ diff --git a/lagent/prompts/parsers/__pycache__/str_parser.cpython-310.pyc b/lagent/prompts/parsers/__pycache__/str_parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f95283dd78cc45d386d5c06b441b5531b0ddf554 Binary files /dev/null and b/lagent/prompts/parsers/__pycache__/str_parser.cpython-310.pyc differ diff --git a/lagent/prompts/parsers/__pycache__/tool_parser.cpython-310.pyc b/lagent/prompts/parsers/__pycache__/tool_parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac1a6572f5dae5d528d044be380673c4bb021c0d Binary files /dev/null and b/lagent/prompts/parsers/__pycache__/tool_parser.cpython-310.pyc differ diff --git a/lagent/prompts/parsers/str_parser.py b/lagent/prompts/parsers/str_parser.py index be997bc8376270e53769549657e1b32b51ff1a19..6af7aa6ecafacf8ff13ad0f3f2cbcdd41d65b735 100644 --- a/lagent/prompts/parsers/str_parser.py +++ b/lagent/prompts/parsers/str_parser.py @@ -1,4 +1,3 @@ -import string from typing import Any @@ -9,17 +8,14 @@ class StrParser: template: str = '', **format_field, ): - fields = {item[1] for item in string.Formatter().parse(template) if item[1] is not None} - if not fields.issubset(format_field.keys()): - raise ValueError( - 'not all required fields of "template" are provided, missing ' - f'{fields - format_field.keys()}. Please pass them as keyword arguments.' 
- ) self.template = template self.format_field = format_field def format_instruction(self) -> Any: - format_data = {key: self.format_to_string(value) for key, value in self.format_field.items()} + format_data = { + key: self.format_to_string(value) + for key, value in self.format_field.items() + } return self.template.format(**format_data) def format_to_string(self, format_model: Any) -> str: diff --git a/lagent/prompts/parsers/tool_parser.py b/lagent/prompts/parsers/tool_parser.py index a8ffea3fc6348caf665e3e48788431d3673c1d5b..534331275b71d1b443aba81bec710eae0deba88d 100644 --- a/lagent/prompts/parsers/tool_parser.py +++ b/lagent/prompts/parsers/tool_parser.py @@ -23,24 +23,29 @@ class ToolStatusCode(IntEnum): class ToolParser(StrParser): - def __init__( - self, - tool_type: str, - template: str = '', - begin: str = '\n', - end: str = '\n', - validate: Callable[[str], Any] = None, - **kwargs - ): + def __init__(self, + tool_type: str, + template: str = '', + begin: str = '\n', + end: str = '\n', + validate: Callable[[str], Any] = None, + **kwargs): super().__init__(template, begin=begin, end=end, **kwargs) self.template = template self.tool_type = tool_type - # self.pattern = re.compile('(.*?){}(.*)({})?'.format(re.escape(begin), re.escape(end)), re.DOTALL) - self.validate = load_class_from_string(validate) if isinstance(validate, str) else validate + # self.pattern = re.compile( + # '(.*?){}(.*)({})?'.format(re.escape(begin), re.escape(end)), + # re.DOTALL) + self.validate = load_class_from_string(validate) if isinstance( + validate, str) else validate def parse_response(self, data: str) -> dict: if self.format_field['begin'] not in data: - return dict(tool_type=None, thought=data, action=None, status=ToolStatusCode.NO_TOOL) + return dict( + tool_type=None, + thought=data, + action=None, + status=ToolStatusCode.NO_TOOL) thought, action, *_ = data.split(self.format_field["begin"]) action = action.split(self.format_field['end'])[0] status = ToolStatusCode.VALID_TOOL @@ -49,7 +54,11 @@ class ToolParser(StrParser): action = self.validate(action) except Exception: status = ToolStatusCode.PARSING_ERROR - return dict(tool_type=self.tool_type, thought=thought, action=action, status=status) + return dict( + tool_type=self.tool_type, + thought=thought, + action=action, + status=status) def format_response(self, parsed: dict) -> str: if parsed['action'] is None: @@ -59,40 +68,41 @@ class ToolParser(StrParser): action = json.dumps(parsed['action'], ensure_ascii=False) else: action = str(parsed['action']) - return parsed['thought'] + self.format_field['begin'] + action + self.format_field['end'] + return parsed['thought'] + self.format_field[ + 'begin'] + action + self.format_field['end'] class InterpreterParser(ToolParser): - def __init__( - self, - tool_type: str = 'interpreter', - template: str = '', - begin: str = '<|action_start|><|interpreter|>\n', - end: str = '<|action_end|>\n', - validate: Callable[[str], Any] = None, - **kwargs - ): + def __init__(self, + tool_type: str = 'interpreter', + template: str = '', + begin: str = '<|action_start|><|interpreter|>\n', + end: str = '<|action_end|>\n', + validate: Callable[[str], Any] = None, + **kwargs): super().__init__(tool_type, template, begin, end, validate, **kwargs) class PluginParser(ToolParser): - def __init__( - self, - tool_type: str = 'plugin', - template: str = '', - begin: str = '<|action_start|><|plugin|>\n', - end: str = '<|action_end|>\n', - validate: Callable[[str], Any] = default_plugin_validate, - **kwargs - ): + def 
__init__(self, + tool_type: str = 'plugin', + template: str = '', + begin: str = '<|action_start|><|plugin|>\n', + end: str = '<|action_end|>\n', + validate: Callable[[str], Any] = default_plugin_validate, + **kwargs): super().__init__(tool_type, template, begin, end, validate, **kwargs) class MixedToolParser(StrParser): - def __init__(self, tool_type: Optional[str] = None, template='', parsers: List[ToolParser] = None, **format_field): + def __init__(self, + tool_type: Optional[str] = None, + template='', + parsers: List[ToolParser] = None, + **format_field): self.parsers = {} self.tool_type = tool_type for parser in parsers or []: @@ -115,7 +125,11 @@ class MixedToolParser(StrParser): return inst def parse_response(self, data: str) -> dict: - res = dict(tool_type=None, thought=data, action=None, status=ToolStatusCode.NO_TOOL) + res = dict( + tool_type=None, + thought=data, + action=None, + status=ToolStatusCode.NO_TOOL) for name, parser in self.parsers.items(): res = parser.parse_response(data) if res['tool_type'] == name: diff --git a/lagent/utils/__pycache__/__init__.cpython-310.pyc b/lagent/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39db06cffc13a1f7708f2a6cfebfaf969d8fc463 Binary files /dev/null and b/lagent/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/lagent/utils/__pycache__/package.cpython-310.pyc b/lagent/utils/__pycache__/package.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc7301604387d50978b70a6ca8126fc0427e1f03 Binary files /dev/null and b/lagent/utils/__pycache__/package.cpython-310.pyc differ diff --git a/lagent/utils/__pycache__/util.cpython-310.pyc b/lagent/utils/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e4118919946e99581d3c6ff620a0e0f3d2b67f5 Binary files /dev/null and b/lagent/utils/__pycache__/util.cpython-310.pyc differ diff --git a/lagent/utils/util.py b/lagent/utils/util.py index 609382ecfceb59587ff8f640c5b2de71213b6711..a40482b53c1c94a1a6f65d64a80f5282b29bd4c2 100644 --- a/lagent/utils/util.py +++ b/lagent/utils/util.py @@ -29,8 +29,8 @@ def load_class_from_string(class_path: str, path=None): def create_object(config: Union[Dict, Any] = None): - """Create an instance based on the configuration where 'type' is a - preserved key to indicate the class (path). When accepting non-dictionary + """Create an instance based on the configuration where 'type' is a + preserved key to indicate the class (path). When accepting non-dictionary input, the function degenerates to an identity. """ if config is None or not isinstance(config, dict): @@ -62,7 +62,8 @@ async def async_as_completed(futures: Iterable[asyncio.Future]): yield await next_completed -def filter_suffix(response: Union[str, List[str]], suffixes: Optional[List[str]] = None) -> str: +def filter_suffix(response: Union[str, List[str]], + suffixes: Optional[List[str]] = None) -> str: """Filter response with suffixes. 
Args: @@ -94,11 +95,12 @@ def filter_suffix(response: Union[str, List[str]], suffixes: Optional[List[str]] def get_logger( name: str = 'lagent', level: str = 'debug', - fmt: str = '%(asctime)s %(levelname)8s %(filename)20s %(lineno)4s - %(message)s', + fmt: + str = '%(asctime)s %(levelname)8s %(filename)20s %(lineno)4s - %(message)s', add_file_handler: bool = False, log_dir: str = 'log', - log_file: str = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) + '.log', - max_bytes: int = 50 * 1024 * 1024, + log_file: str = time.strftime('%Y-%m-%d.log', time.localtime()), + max_bytes: int = 5 * 1024 * 1024, backup_count: int = 3, ): logger = logging.getLogger(name) @@ -115,8 +117,10 @@ def get_logger( os.makedirs(log_dir) log_file_path = osp.join(log_dir, log_file) file_handler = RotatingFileHandler( - log_file_path, maxBytes=max_bytes, backupCount=backup_count, encoding='utf-8' - ) + log_file_path, + maxBytes=max_bytes, + backupCount=backup_count, + encoding='utf-8') file_handler.setFormatter(formatter) logger.addHandler(file_handler) diff --git a/lagent/version.py b/lagent/version.py index 01c3552eea5c5f35158eee3452a8e5127b592eac..d9c59dd319e7e1be2581dc9330a2ed8120073173 100644 --- a/lagent/version.py +++ b/lagent/version.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -__version__ = '0.5.0rc2' +__version__ = '0.5.0rc1' def parse_version_info(version_str): diff --git a/requirements/optional.txt b/requirements/optional.txt index 0ae76dff9f6b42dba71cd4a8481b13822acc2cf9..75645dbe7bcd072298adc6660b7a139857171045 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -1,4 +1,3 @@ -duckduckgo_search==5.3.1b1 google-search-results lmdeploy>=0.2.5 pillow diff --git a/requirements/runtime.txt b/requirements/runtime.txt index ac0b85c7083e53a8736539ea4227702bc075269e..6fcd4ea1c8a27b67417be5c9bf4079340341c437 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,9 +1,9 @@ aiohttp -anthropic arxiv asyncache asyncer distro +duckduckgo_search==5.3.1b1 filelock func_timeout griffe<1.0 @@ -14,7 +14,6 @@ jupyter_client==8.6.2 jupyter_core==5.7.2 pydantic==2.6.4 requests -tenacity termcolor tiktoken timeout-decorator
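Note on the DashScope-style qwen branch this patch adds to generate_request_data in lagent/llms/openai.py: unlike the OpenAI payload, DashScope nests the chat messages under an "input" key and the sampling options under "parameters", expects repetition_penalty rather than OpenAI's frequency_penalty, and streams through server-sent events enabled by the "X-DashScope-SSE: enable" header. Below is a minimal standalone sketch of that assembly; the function name build_qwen_request is illustrative only, and authentication headers (handled elsewhere in the real class) are deliberately omitted.

    # Sketch of the qwen/DashScope branch of generate_request_data.
    # Hypothetical standalone helper; auth headers are intentionally left out.
    def build_qwen_request(model_type: str, messages: list, gen_params: dict):
        header = {
            'content-type': 'application/json',
            'X-DashScope-SSE': 'enable',  # DashScope streams via server-sent events
        }
        params = dict(gen_params)
        params.pop('skip_special_tokens', None)  # OpenAI-only knobs are dropped
        params.pop('session_id', None)
        if 'frequency_penalty' in params:
            # DashScope names this repetition_penalty, undoing the OpenAI-bound
            # rename performed earlier in generate_request_data
            params['repetition_penalty'] = params.pop('frequency_penalty')
        params['result_format'] = 'message'
        data = {
            'model': model_type,
            'input': {'messages': messages},  # messages nest under `input`
            'parameters': {**params},  # sampling options nest under `parameters`
        }
        return header, data

    if __name__ == '__main__':
        hdr, payload = build_qwen_request(
            'qwen-max',
            [{'role': 'user', 'content': 'Hi'}],
            {'temperature': 0.7, 'frequency_penalty': 1.05},
        )
        assert payload['parameters']['repetition_penalty'] == 1.05

Because each streamed event in this mode carries the message accumulated so far rather than a delta, the patched consumers overwrite their buffer (resp = text) for qwen models instead of appending as the OpenAI branch does.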