markqiu commited on
Commit
569cdb0
·
1 Parent(s): 10bbb8e

百度文心一言的例子

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. erniebot-agent/README.md +44 -0
  2. erniebot-agent/erniebot_agent/__init__.py +19 -0
  3. erniebot-agent/erniebot_agent/agents/__init__.py +13 -0
  4. erniebot-agent/erniebot_agent/agents/base.py +279 -0
  5. erniebot-agent/erniebot_agent/agents/callback/__init__.py +13 -0
  6. erniebot-agent/erniebot_agent/agents/callback/callback_manager.py +94 -0
  7. erniebot-agent/erniebot_agent/agents/callback/default.py +22 -0
  8. erniebot-agent/erniebot_agent/agents/callback/event.py +26 -0
  9. erniebot-agent/erniebot_agent/agents/callback/handlers/__init__.py +13 -0
  10. erniebot-agent/erniebot_agent/agents/callback/handlers/base.py +55 -0
  11. erniebot-agent/erniebot_agent/agents/callback/handlers/logging_handler.py +107 -0
  12. erniebot-agent/erniebot_agent/agents/functional_agent.py +148 -0
  13. erniebot-agent/erniebot_agent/agents/schema.py +93 -0
  14. erniebot-agent/erniebot_agent/chat_models/__init__.py +17 -0
  15. erniebot-agent/erniebot_agent/chat_models/base.py +60 -0
  16. erniebot-agent/erniebot_agent/chat_models/erniebot.py +135 -0
  17. erniebot-agent/erniebot_agent/extensions/langchain/chat_models/__init__.py +1 -0
  18. erniebot-agent/erniebot_agent/extensions/langchain/chat_models/erniebot.py +356 -0
  19. erniebot-agent/erniebot_agent/extensions/langchain/embeddings/__init__.py +1 -0
  20. erniebot-agent/erniebot_agent/extensions/langchain/embeddings/ernie.py +82 -0
  21. erniebot-agent/erniebot_agent/extensions/langchain/llms/__init__.py +1 -0
  22. erniebot-agent/erniebot_agent/extensions/langchain/llms/erniebot.py +239 -0
  23. erniebot-agent/erniebot_agent/file_io/__init__.py +13 -0
  24. erniebot-agent/erniebot_agent/file_io/base.py +46 -0
  25. erniebot-agent/erniebot_agent/file_io/file_manager.py +138 -0
  26. erniebot-agent/erniebot_agent/file_io/file_registry.py +55 -0
  27. erniebot-agent/erniebot_agent/file_io/local_file.py +55 -0
  28. erniebot-agent/erniebot_agent/file_io/protocol.py +57 -0
  29. erniebot-agent/erniebot_agent/file_io/remote_file.py +153 -0
  30. erniebot-agent/erniebot_agent/memory/__init__.py +18 -0
  31. erniebot-agent/erniebot_agent/memory/base.py +99 -0
  32. erniebot-agent/erniebot_agent/memory/limit_token_memory.py +59 -0
  33. erniebot-agent/erniebot_agent/memory/sliding_window_memory.py +41 -0
  34. erniebot-agent/erniebot_agent/memory/whole_memory.py +19 -0
  35. erniebot-agent/erniebot_agent/messages.py +124 -0
  36. erniebot-agent/erniebot_agent/prompt/__init__.py +16 -0
  37. erniebot-agent/erniebot_agent/prompt/base.py +28 -0
  38. erniebot-agent/erniebot_agent/prompt/prompt_template.py +80 -0
  39. erniebot-agent/erniebot_agent/retrieval/__init__.py +0 -0
  40. erniebot-agent/erniebot_agent/retrieval/baizhong_search.py +296 -0
  41. erniebot-agent/erniebot_agent/retrieval/document.py +123 -0
  42. erniebot-agent/erniebot_agent/tools/__init__.py +15 -0
  43. erniebot-agent/erniebot_agent/tools/baizhong_tool.py +65 -0
  44. erniebot-agent/erniebot_agent/tools/base.py +428 -0
  45. erniebot-agent/erniebot_agent/tools/calculator_tool.py +66 -0
  46. erniebot-agent/erniebot_agent/tools/current_time_tool.py +47 -0
  47. erniebot-agent/erniebot_agent/tools/image_generation_tool.py +117 -0
  48. erniebot-agent/erniebot_agent/tools/schema.py +415 -0
  49. erniebot-agent/erniebot_agent/tools/tool_manager.py +69 -0
  50. erniebot-agent/erniebot_agent/utils/__init__.py +0 -0
erniebot-agent/README.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ <h1>ERNIE Bot Agent</h1>
4
+
5
+ ERNIE Bot Agent 可以快速开发智能体。
6
+
7
+ [![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
8
+ ![python version](https://img.shields.io/badge/python-3.8+-orange.svg)
9
+ ![support os](https://img.shields.io/badge/os-linux%2C%20win%2C%20mac-yellow.svg)
10
+
11
+ </div>
12
+
13
+ `ERNIE Bot Agent` 旨在为开发者提供快速搭建大模型Agent和应用的框架。该项目还在积极研发中,敬请期待我们后续的正式发版。
14
+
15
+ ## 主要功能
16
+
17
+ ### 大模型 Agent 框架
18
+
19
+ `ERNIE Bot Agent` 将结合飞桨星河AI Studio社区,为开发者提供一站式的大模型Agent和应用搭建框架和平台。该项目还在积极研发中,敬请期待我们后续的正式发版。
20
+
21
+ ### 文心 LangChain 插件
22
+
23
+ 为了让大家更加高效、便捷地结合文心大模型与LangChain进行开发,`ERNIE Bot Agent`对`LangChain`框架进行了功能扩展,提供了基于文心大模型的大语言模型(LLM)组件、聊天模型(ChatModel)组件以及文本嵌入模型(Text Embedding Model)组件。详情请参见[使用范例Notebook](https://github.com/PaddlePaddle/ERNIE-Bot-SDK/blob/develop/erniebot-agent/examples/cookbook/how_to_use_langchain_extension.ipynb)。
24
+
25
+
26
+ ## 快速安装
27
+
28
+ 建议您可以使用pip快速安装 ERNIE Bot Agent 的最新稳定版。
29
+
30
+ ```shell
31
+ pip install --upgrade erniebot-agent
32
+ ```
33
+
34
+ 如需使用develop版本,可以下载源码后执行如下命令安装
35
+
36
+ ```shell
37
+ git clone https://github.com/PaddlePaddle/ERNIE-Bot-SDK.git
38
+ cd ERNIE-Bot-SDK/erniebot-agent
39
+ pip install .
40
+ ```
41
+
42
+ ## License
43
+
44
+ ERNIE Bot Agent遵循Apache-2.0开源协议。
erniebot-agent/erniebot_agent/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
from erniebot_agent.utils.logging import logger, setup_logging

# Only `logger` is the public API of this package root; `setup_logging` is
# imported so the module-level configuration call below can run.
__all__ = ["logger"]

# Configure the package-wide logger once, as an import-time side effect.
setup_logging()
erniebot-agent/erniebot_agent/agents/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
erniebot-agent/erniebot_agent/agents/base.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import abc
16
+ import inspect
17
+ import json
18
+ from typing import Any, Dict, List, Literal, Optional, Union
19
+
20
+ from erniebot_agent.agents.callback.callback_manager import CallbackManager
21
+ from erniebot_agent.agents.callback.default import get_default_callbacks
22
+ from erniebot_agent.agents.callback.handlers.base import CallbackHandler
23
+ from erniebot_agent.agents.schema import (
24
+ AgentFile,
25
+ AgentResponse,
26
+ LLMResponse,
27
+ ToolResponse,
28
+ )
29
+ from erniebot_agent.chat_models.base import ChatModel
30
+ from erniebot_agent.file_io.file_manager import FileManager
31
+ from erniebot_agent.file_io.protocol import is_local_file_id, is_remote_file_id
32
+ from erniebot_agent.memory.base import Memory
33
+ from erniebot_agent.messages import Message, SystemMessage
34
+ from erniebot_agent.tools.base import Tool
35
+ from erniebot_agent.tools.tool_manager import ToolManager
36
+ from erniebot_agent.utils.logging import logger
37
+
38
+
39
class BaseAgent(metaclass=abc.ABCMeta):
    """Minimal agent interface: a chat model plus a conversation memory."""

    # Concrete agents are expected to provide these (see `Agent.__init__`).
    llm: ChatModel
    memory: Memory

    @abc.abstractmethod
    async def async_run(self, prompt: str) -> AgentResponse:
        """Run the agent on `prompt` and return its final response."""
        raise NotImplementedError
46
+
47
+
48
class Agent(BaseAgent):
    """Skeleton for concrete agents.

    Wires together a chat model, a tool set, a conversation memory, and a
    set of callback handlers, and funnels every LLM/tool invocation through
    the callback lifecycle (``on_*_start`` / ``on_*_end`` / ``on_*_error``).
    Subclasses implement `_async_run` to define the actual reasoning loop.
    """

    def __init__(
        self,
        llm: ChatModel,
        tools: Union[ToolManager, List[Tool]],
        memory: Memory,
        system_message: Optional[SystemMessage] = None,
        *,
        callbacks: Optional[Union[CallbackManager, List[CallbackHandler]]] = None,
        file_manager: Optional[FileManager] = None,
    ) -> None:
        """Initialize the agent.

        Args:
            llm: Chat model used for reasoning.
            tools: Tools the agent may call, either a prebuilt `ToolManager`
                or a plain list of `Tool` objects.
            memory: Conversation memory.
            system_message: When given, overrides the system message stored
                in `memory`.
            callbacks: Callback handlers (or a prebuilt `CallbackManager`);
                defaults to `get_default_callbacks()`.
            file_manager: Used to resolve file IDs found in tool arguments
                and results; file sniffing is skipped (with a warning) when
                this is `None`.
        """
        super().__init__()
        self.llm = llm
        self.memory = memory
        # Prefer the system message passed to the agent; otherwise fall back
        # to the one already stored in memory.
        if system_message:
            self.system_message = system_message
        else:
            self.system_message = memory.get_system_message()
        if isinstance(tools, ToolManager):
            self._tool_manager = tools
        else:
            self._tool_manager = ToolManager(tools)
        if callbacks is None:
            callbacks = get_default_callbacks()
        if isinstance(callbacks, CallbackManager):
            self._callback_manager = callbacks
        else:
            self._callback_manager = CallbackManager(callbacks)
        self.file_manager = file_manager

    async def async_run(self, prompt: str) -> AgentResponse:
        """Run the agent on `prompt`, surrounding the run with callbacks."""
        await self._callback_manager.on_run_start(agent=self, prompt=prompt)
        agent_resp = await self._async_run(prompt)
        await self._callback_manager.on_run_end(agent=self, response=agent_resp)
        return agent_resp

    def load_tool(self, tool: Tool) -> None:
        """Make `tool` available to the agent."""
        self._tool_manager.add_tool(tool)

    def unload_tool(self, tool: Tool) -> None:
        """Remove `tool` from the agent's tool set."""
        self._tool_manager.remove_tool(tool)

    def reset_memory(self) -> None:
        """Drop the whole chat history from memory."""
        self.memory.clear_chat_history()

    def launch_gradio_demo(self, **launch_kwargs: Any):
        """Launch a Gradio chat demo for this agent.

        `launch_kwargs` are forwarded to ``gradio.Blocks.launch()``.
        Requires the optional ``gradio`` dependency.
        """
        # TODO: Unified optional dependencies management
        try:
            import gradio as gr
        except ImportError:
            raise ImportError(
                "Could not import gradio, which is required for `launch_gradio_demo()`."
                " Please run `pip install erniebot-agent[gradio]` to install the optional dependencies."
            ) from None

        # Accumulates every message produced across runs, for display only.
        raw_messages = []

        def _pre_chat(text, history):
            # Append the user turn and lock the input widgets while running.
            history.append([text, None])
            return history, gr.update(value="", interactive=False), gr.update(interactive=False)

        async def _chat(history):
            prompt = history[-1][0]
            if len(prompt) == 0:
                raise gr.Error("Prompt should not be empty.")
            response = await self.async_run(prompt)
            history[-1][1] = response.text
            raw_messages.extend(response.chat_history)
            return (
                history,
                _messages_to_dicts(raw_messages),
                _messages_to_dicts(self.memory.get_messages()),
            )

        def _post_chat():
            # Re-enable the input widgets after a run completes.
            return gr.update(interactive=True), gr.update(interactive=True)

        def _clear():
            raw_messages.clear()
            self.reset_memory()
            return None, None, None, None

        def _messages_to_dicts(messages):
            return [message.to_dict() for message in messages]

        with gr.Blocks(
            title="ERNIE Bot Agent Demo", theme=gr.themes.Soft(spacing_size="sm", text_size="md")
        ) as demo:
            with gr.Column():
                chatbot = gr.Chatbot(
                    label="Chat history",
                    latex_delimiters=[
                        {"left": "$$", "right": "$$", "display": True},
                        {"left": "$", "right": "$", "display": False},
                    ],
                    bubble_full_width=False,
                )
                prompt_textbox = gr.Textbox(label="Prompt", placeholder="Write a prompt here...")
                with gr.Row():
                    submit_button = gr.Button("Submit")
                    clear_button = gr.Button("Clear")
                with gr.Accordion("Tools", open=False):
                    attached_tools = self._tool_manager.get_tools()
                    tool_descriptions = [tool.function_call_schema() for tool in attached_tools]
                    gr.JSON(value=tool_descriptions)
                with gr.Accordion("Raw messages", open=False):
                    all_messages_json = gr.JSON(label="All messages")
                    agent_memory_json = gr.JSON(label="Messages in memory")
            # Pressing Enter in the textbox and clicking "Submit" trigger the
            # same pre-chat -> chat -> post-chat pipeline.
            for submit_trigger in (prompt_textbox.submit, submit_button.click):
                submit_trigger(
                    _pre_chat,
                    inputs=[prompt_textbox, chatbot],
                    outputs=[chatbot, prompt_textbox, submit_button],
                ).then(
                    _chat,
                    inputs=[chatbot],
                    outputs=[
                        chatbot,
                        all_messages_json,
                        agent_memory_json,
                    ],
                ).then(
                    _post_chat, outputs=[prompt_textbox, submit_button]
                )
            clear_button.click(
                _clear,
                outputs=[
                    chatbot,
                    prompt_textbox,
                    all_messages_json,
                    agent_memory_json,
                ],
            )

        demo.launch(**launch_kwargs)

    @abc.abstractmethod
    async def _async_run(self, prompt: str) -> AgentResponse:
        """Subclass hook: the actual reasoning loop behind `async_run`."""
        raise NotImplementedError

    async def _async_run_tool(self, tool_name: str, tool_args: str) -> ToolResponse:
        """Run the named tool with JSON-encoded args, firing tool callbacks.

        On failure the exception is reported via `on_tool_error` and
        re-raised.
        """
        tool = self._tool_manager.get_tool(tool_name)
        await self._callback_manager.on_tool_start(agent=self, tool=tool, input_args=tool_args)
        try:
            tool_resp = await self._async_run_tool_without_hooks(tool, tool_args)
        except (Exception, KeyboardInterrupt) as e:
            await self._callback_manager.on_tool_error(agent=self, tool=tool, error=e)
            raise
        await self._callback_manager.on_tool_end(agent=self, tool=tool, response=tool_resp)
        return tool_resp

    async def _async_run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
        """Invoke the LLM on `messages`, firing LLM callbacks.

        On failure the exception is reported via `on_llm_error` and
        re-raised.
        """
        await self._callback_manager.on_llm_start(agent=self, llm=self.llm, messages=messages)
        try:
            llm_resp = await self._async_run_llm_without_hooks(messages, **opts)
        except (Exception, KeyboardInterrupt) as e:
            await self._callback_manager.on_llm_error(agent=self, llm=self.llm, error=e)
            raise
        await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp)
        return llm_resp

    async def _async_run_tool_without_hooks(self, tool: Tool, tool_args: str) -> ToolResponse:
        """Bind args, run the tool, and collect file IDs it consumed/produced."""
        bnd_args = self._parse_tool_args(tool, tool_args)
        # XXX: Sniffing is less efficient and probably unnecessary.
        # Can we make a protocol to statically recognize file inputs and outputs
        # or can we have the tools introspect about this?
        input_files = await self._sniff_and_extract_files_from_args(bnd_args.arguments, tool, "input")
        tool_ret = await tool(*bnd_args.args, **bnd_args.kwargs)
        # NOTE(review): this assumes the tool's return value is a dict — see
        # `_sniff_and_extract_files_from_args`, which iterates `.values()`.
        output_files = await self._sniff_and_extract_files_from_args(tool_ret, tool, "output")
        tool_ret_json = json.dumps(tool_ret, ensure_ascii=False)
        return ToolResponse(json=tool_ret_json, files=input_files + output_files)

    async def _async_run_llm_without_hooks(
        self, messages: List[Message], functions=None, **opts: Any
    ) -> LLMResponse:
        """Call the chat model directly (non-streaming), without callbacks."""
        llm_ret = await self.llm.async_chat(messages, functions=functions, stream=False, **opts)
        return LLMResponse(message=llm_ret)

    def _parse_tool_args(self, tool: Tool, tool_args: str) -> inspect.BoundArguments:
        """Decode the JSON `tool_args` and bind them to `tool.__call__`.

        Raises:
            ValueError: If `tool_args` does not decode to a JSON object.
        """
        args_dict = json.loads(tool_args)
        if not isinstance(args_dict, dict):
            raise ValueError("`tool_args` cannot be interpreted as a dict.")
        # TODO: Check types
        sig = inspect.signature(tool.__call__)
        bnd_args = sig.bind(**args_dict)
        bnd_args.apply_defaults()
        return bnd_args

    async def _sniff_and_extract_files_from_args(
        self, args: Dict[str, Any], tool: Tool, file_type: Literal["input", "output"]
    ) -> List[AgentFile]:
        """Scan string values of `args` for file IDs and resolve them to files.

        Local IDs must already be registered with the file manager; remote
        IDs are fetched on demand. Without a file manager, file-looking
        values are skipped with a warning.
        """
        agent_files: List[AgentFile] = []
        for val in args.values():
            if isinstance(val, str):
                if is_local_file_id(val):
                    if self.file_manager is None:
                        logger.warning(
                            f"A file is used by {repr(tool)}, but the agent has no file manager to fetch it."
                        )
                        continue
                    file = self.file_manager.look_up_file_by_id(val)
                    if file is None:
                        raise RuntimeError(f"Unregistered ID {repr(val)} is used by {repr(tool)}.")
                elif is_remote_file_id(val):
                    if self.file_manager is None:
                        logger.warning(
                            f"A file is used by {repr(tool)}, but the agent has no file manager to fetch it."
                        )
                        continue
                    file = self.file_manager.look_up_file_by_id(val)
                    if file is None:
                        # Not cached locally; pull it from the remote service.
                        file = await self.file_manager.retrieve_remote_file_by_id(val)
                else:
                    continue
                agent_files.append(AgentFile(file=file, type=file_type, used_by=tool.tool_name))
        return agent_files
erniebot-agent/erniebot_agent/agents/callback/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
erniebot-agent/erniebot_agent/agents/callback/callback_manager.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import inspect
18
+ from typing import TYPE_CHECKING, Any, List, Union, final
19
+
20
+ from erniebot_agent.agents.callback.event import EventType
21
+ from erniebot_agent.agents.callback.handlers.base import CallbackHandler
22
+ from erniebot_agent.agents.schema import AgentResponse, LLMResponse, ToolResponse
23
+ from erniebot_agent.chat_models.base import ChatModel
24
+ from erniebot_agent.messages import Message
25
+ from erniebot_agent.tools.base import Tool
26
+
27
+ if TYPE_CHECKING:
28
+ from erniebot_agent.agents.base import Agent
29
+
30
+
31
@final
class CallbackManager(object):
    """Dispatches agent lifecycle events to a list of callback handlers.

    Each ``on_*`` method fans the event out, in registration order, to the
    matching coroutine method of every registered handler.
    """

    def __init__(self, handlers: List[CallbackHandler]):
        super().__init__()
        self._handlers = handlers

    @property
    def handlers(self) -> List[CallbackHandler]:
        """The currently registered handlers."""
        return self._handlers

    def add_handler(self, handler: CallbackHandler):
        """Register `handler`. Raises RuntimeError if already registered."""
        if handler in self._handlers:
            raise RuntimeError(f"The callback handler {handler} is already registered.")
        self._handlers.append(handler)

    def remove_handler(self, handler):
        """Unregister `handler`. Raises RuntimeError if not registered."""
        try:
            self._handlers.remove(handler)
        except ValueError as e:
            raise RuntimeError(f"The callback handler {handler} is not registered.") from e

    def set_handlers(self, handlers: List[CallbackHandler]):
        """Replace all registered handlers with `handlers`."""
        self._handlers = []
        for handler in handlers:
            self.add_handler(handler)

    def remove_all_handlers(self):
        """Unregister every handler."""
        self._handlers = []

    async def handle_event(self, event_type: EventType, *args: Any, **kwargs: Any) -> None:
        """Invoke ``on_<event>`` on every handler that implements it.

        Handlers that do not provide the callback are silently skipped
        (that is why the `getattr` default is `None`); a callback that
        exists but is not a coroutine function is a programming error.
        """
        callback_name = "on_" + event_type.value
        for handler in self._handlers:
            callback = getattr(handler, callback_name, None)
            if callback is None:
                # The handler does not care about this event; skip it
                # instead of raising a misleading TypeError.
                continue
            if not inspect.iscoroutinefunction(callback):
                raise TypeError(
                    f"Callback {callback_name!r} of {handler!r} must be a coroutine function."
                )
            await callback(*args, **kwargs)

    async def on_run_start(self, agent: Agent, prompt: str) -> None:
        await self.handle_event(EventType.RUN_START, agent=agent, prompt=prompt)

    async def on_llm_start(self, agent: Agent, llm: ChatModel, messages: List[Message]) -> None:
        await self.handle_event(EventType.LLM_START, agent=agent, llm=llm, messages=messages)

    async def on_llm_end(self, agent: Agent, llm: ChatModel, response: LLMResponse) -> None:
        await self.handle_event(EventType.LLM_END, agent=agent, llm=llm, response=response)

    async def on_llm_error(
        self, agent: Agent, llm: ChatModel, error: Union[Exception, KeyboardInterrupt]
    ) -> None:
        await self.handle_event(EventType.LLM_ERROR, agent=agent, llm=llm, error=error)

    async def on_tool_start(self, agent: Agent, tool: Tool, input_args: str) -> None:
        await self.handle_event(EventType.TOOL_START, agent=agent, tool=tool, input_args=input_args)

    async def on_tool_end(self, agent: Agent, tool: Tool, response: ToolResponse) -> None:
        await self.handle_event(EventType.TOOL_END, agent=agent, tool=tool, response=response)

    async def on_tool_error(
        self, agent: Agent, tool: Tool, error: Union[Exception, KeyboardInterrupt]
    ) -> None:
        await self.handle_event(EventType.TOOL_ERROR, agent=agent, tool=tool, error=error)

    async def on_run_end(self, agent: Agent, response: AgentResponse) -> None:
        await self.handle_event(EventType.RUN_END, agent=agent, response=response)
erniebot-agent/erniebot_agent/agents/callback/default.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List
16
+
17
+ from erniebot_agent.agents.callback.handlers.base import CallbackHandler
18
+ from erniebot_agent.agents.callback.handlers.logging_handler import LoggingHandler
19
+
20
+
21
def get_default_callbacks() -> List[CallbackHandler]:
    """Build the callback handlers installed when a caller supplies none.

    Currently this is just a console/logging handler.
    """
    default_handlers: List[CallbackHandler] = [LoggingHandler()]
    return default_handlers
erniebot-agent/erniebot_agent/agents/callback/event.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import enum
16
+
17
+
18
@enum.unique
class EventType(enum.Enum):
    """Agent lifecycle events dispatched by the callback manager.

    Each value is the suffix of the handler hook name: an event with value
    ``x`` is delivered to the ``on_x`` coroutine of every handler, so the
    values must stay distinct (`@enum.unique` guards against aliasing).
    """

    RUN_START = "run_start"
    LLM_START = "llm_start"
    LLM_END = "llm_end"
    LLM_ERROR = "llm_error"
    TOOL_START = "tool_start"
    TOOL_END = "tool_end"
    TOOL_ERROR = "tool_error"
    RUN_END = "run_end"
erniebot-agent/erniebot_agent/agents/callback/handlers/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
erniebot-agent/erniebot_agent/agents/callback/handlers/base.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import TYPE_CHECKING, List, Union
18
+
19
+ from erniebot_agent.agents.schema import AgentResponse, LLMResponse, ToolResponse
20
+ from erniebot_agent.chat_models.base import ChatModel
21
+ from erniebot_agent.messages import Message
22
+ from erniebot_agent.tools.base import Tool
23
+
24
+ if TYPE_CHECKING:
25
+ from erniebot_agent.agents.base import Agent
26
+
27
+
28
class CallbackHandler(object):
    """No-op base class for agent callback handlers.

    `CallbackManager` invokes these coroutine hooks at well-defined points
    of an agent run. Subclasses override only the hooks they care about;
    every default implementation does nothing.
    """

    async def on_run_start(self, agent: Agent, prompt: str) -> None:
        """Called when an agent run begins, before any LLM or tool calls."""

    async def on_llm_start(self, agent: Agent, llm: ChatModel, messages: List[Message]) -> None:
        """Called right before the LLM is invoked with `messages`."""

    async def on_llm_end(self, agent: Agent, llm: ChatModel, response: LLMResponse) -> None:
        """Called after the LLM call returns successfully."""

    async def on_llm_error(
        self, agent: Agent, llm: ChatModel, error: Union[Exception, KeyboardInterrupt]
    ) -> None:
        """Called when the LLM call raises; the agent re-raises `error` afterwards."""

    async def on_tool_start(self, agent: Agent, tool: Tool, input_args: str) -> None:
        """Called right before `tool` runs; `input_args` is the raw JSON argument string."""

    async def on_tool_end(self, agent: Agent, tool: Tool, response: ToolResponse) -> None:
        """Called after the tool call returns successfully."""

    async def on_tool_error(
        self, agent: Agent, tool: Tool, error: Union[Exception, KeyboardInterrupt]
    ) -> None:
        """Called when the tool call raises; the agent re-raises `error` afterwards."""

    async def on_run_end(self, agent: Agent, response: AgentResponse) -> None:
        """Called when an agent run finishes successfully."""
erniebot-agent/erniebot_agent/agents/callback/handlers/logging_handler.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import logging
18
+ from typing import TYPE_CHECKING, List, Optional, Union
19
+
20
+ from erniebot_agent.agents.callback.handlers.base import CallbackHandler
21
+ from erniebot_agent.agents.schema import AgentResponse, LLMResponse, ToolResponse
22
+ from erniebot_agent.chat_models.base import ChatModel
23
+ from erniebot_agent.messages import Message
24
+ from erniebot_agent.tools.base import Tool
25
+ from erniebot_agent.utils.json import to_pretty_json
26
+ from erniebot_agent.utils.logging import logger as default_logger
27
+
28
+ if TYPE_CHECKING:
29
+ from erniebot_agent.agents.base import Agent
30
+
31
+
32
class LoggingHandler(CallbackHandler):
    """Callback handler that logs the agent's run/LLM/tool lifecycle.

    Messages are tagged with a ``[subject][state]`` prefix (e.g.
    ``[LLM][Start]``) so the stages of a run are easy to follow in logs.
    """

    logger: logging.Logger

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        """Use `logger` for output, or the package's default logger when None."""
        super().__init__()
        if logger is None:
            self.logger = default_logger
        else:
            self.logger = logger

    async def on_run_start(self, agent: Agent, prompt: str) -> None:
        self.agent_info(
            "%s is about to start running with input: %s\n",
            agent.__class__.__name__,
            prompt,
            subject="Run",
            state="Start",
        )

    async def on_llm_start(self, agent: Agent, llm: ChatModel, messages: List[Message]) -> None:
        # TODO: Prettier messages
        self.agent_info(
            "%s is about to start running with input:\n%s\n",
            llm.__class__.__name__,
            messages,
            subject="LLM",
            state="Start",
        )

    async def on_llm_end(self, agent: Agent, llm: ChatModel, response: LLMResponse) -> None:
        self.agent_info(
            "%s finished running with output: %s\n",
            llm.__class__.__name__,
            response.message,
            subject="LLM",
            state="End",
        )

    async def on_llm_error(
        self, agent: Agent, llm: ChatModel, error: Union[Exception, KeyboardInterrupt]
    ) -> None:
        # Log the failure (it was previously silently ignored even though
        # `agent_error` existed for exactly this); the agent re-raises.
        self.agent_error(error, subject="LLM")

    async def on_tool_start(self, agent: Agent, tool: Tool, input_args: str) -> None:
        self.agent_info(
            "%s is about to start running with input:\n%s\n",
            tool.__class__.__name__,
            to_pretty_json(input_args, from_json=True),
            subject="Tool",
            state="Start",
        )

    async def on_tool_end(self, agent: Agent, tool: Tool, response: ToolResponse) -> None:
        self.agent_info(
            "%s finished running with output:\n%s\n",
            tool.__class__.__name__,
            to_pretty_json(response.json, from_json=True),
            subject="Tool",
            state="End",
        )

    async def on_tool_error(
        self, agent: Agent, tool: Tool, error: Union[Exception, KeyboardInterrupt]
    ) -> None:
        # Log the failure (it was previously silently ignored); the agent
        # re-raises the exception after this hook.
        self.agent_error(error, subject="Tool")

    async def on_run_end(self, agent: Agent, response: AgentResponse) -> None:
        self.agent_info("%s finished running.\n", agent.__class__.__name__, subject="Run", state="End")

    def agent_info(self, msg: str, *args, subject, state, **kwargs) -> None:
        """Log an info-level message with the ``[subject][state]`` prefix."""
        msg = f"[{subject}][{state}] {msg}"
        self.logger.info(msg, *args, **kwargs)

    def agent_error(self, error: Union[Exception, KeyboardInterrupt], *args, subject, **kwargs) -> None:
        """Log an error-level message with the ``[subject][ERROR]`` prefix."""
        error_msg = f"[{subject}][ERROR] {error}"
        self.logger.error(error_msg, *args, **kwargs)
erniebot-agent/erniebot_agent/agents/functional_agent.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Optional, Union
16
+
17
+ from erniebot_agent.agents.base import Agent, ToolManager
18
+ from erniebot_agent.agents.callback.callback_manager import CallbackManager
19
+ from erniebot_agent.agents.callback.handlers.base import CallbackHandler
20
+ from erniebot_agent.agents.schema import AgentAction, AgentFile, AgentResponse
21
+ from erniebot_agent.chat_models.base import ChatModel
22
+ from erniebot_agent.file_io.file_manager import FileManager
23
+ from erniebot_agent.memory.base import Memory
24
+ from erniebot_agent.messages import (
25
+ FunctionMessage,
26
+ HumanMessage,
27
+ Message,
28
+ SystemMessage,
29
+ )
30
+ from erniebot_agent.tools.base import Tool
31
+
32
+ _MAX_STEPS = 5
33
+
34
+
35
class FunctionalAgent(Agent):
    """An agent that uses the LLM's function-calling ability to drive tools.

    Each step asks the LLM to plan; if the reply contains a function call, the
    named tool is executed and its JSON result is fed back as a
    ``FunctionMessage``. The run finishes when the LLM answers without a
    function call, or stops early after ``max_steps`` steps.
    """

    def __init__(
        self,
        llm: ChatModel,
        tools: Union[ToolManager, List[Tool]],
        memory: Memory,
        system_message: Optional[SystemMessage] = None,
        *,
        callbacks: Optional[Union[CallbackManager, List[CallbackHandler]]] = None,
        file_manager: Optional[FileManager] = None,
        max_steps: Optional[int] = None,
    ) -> None:
        """Initialize the agent.

        Args:
            llm: Chat model used for planning and answering.
            tools: Tools the agent may call, as a ToolManager or a plain list.
            memory: Conversation memory; updated only when a run finishes.
            system_message: Optional system message for the dialogue.
            callbacks: Callback manager or handlers for run events.
            file_manager: Manager for files exchanged with tools.
            max_steps: Maximum tool steps per run; must be positive. Defaults
                to the module-level _MAX_STEPS.

        Raises:
            ValueError: If `max_steps` is not positive.
        """
        super().__init__(
            llm=llm,
            tools=tools,
            memory=memory,
            system_message=system_message,
            callbacks=callbacks,
            file_manager=file_manager,
        )
        if max_steps is not None:
            if max_steps <= 0:
                raise ValueError("Invalid `max_steps` value")
            self.max_steps = max_steps
        else:
            self.max_steps = _MAX_STEPS

    async def _async_run(self, prompt: str) -> AgentResponse:
        """Run the plan-act loop for `prompt` until done or out of steps."""
        chat_history: List[Message] = []
        actions_taken: List[AgentAction] = []
        files_involved: List[AgentFile] = []
        ask = HumanMessage(content=prompt)

        num_steps_taken = 0
        next_step_input: Message = ask
        while num_steps_taken < self.max_steps:
            curr_step_output = await self._async_step(
                next_step_input, chat_history, actions_taken, files_involved
            )
            if curr_step_output is None:
                # No further tool call requested: the last AI message is the answer.
                response = self._create_finished_response(chat_history, actions_taken, files_involved)
                # Persist only the user's question and the final answer to memory;
                # intermediate tool traffic is deliberately not stored.
                self.memory.add_message(chat_history[0])
                self.memory.add_message(chat_history[-1])
                return response
            num_steps_taken += 1
            next_step_input = curr_step_output
        response = self._create_stopped_response(chat_history, actions_taken, files_involved)
        return response

    async def _async_step(
        self,
        step_input: Message,
        chat_history: List[Message],
        actions: List[AgentAction],
        files: List[AgentFile],
    ) -> Optional[Message]:
        """Execute one plan-act step.

        Returns:
            The FunctionMessage to feed into the next step, or None when the
            LLM produced a final answer instead of a tool call.
        """
        maybe_action = await self._async_plan(step_input, chat_history)
        if isinstance(maybe_action, AgentAction):
            action: AgentAction = maybe_action
            tool_resp = await self._async_run_tool(tool_name=action.tool_name, tool_args=action.tool_args)
            actions.append(action)
            files.extend(tool_resp.files)
            return FunctionMessage(name=action.tool_name, content=tool_resp.json)
        else:
            return None

    async def _async_plan(
        self, input_message: Message, chat_history: List[Message]
    ) -> Optional[AgentAction]:
        """Ask the LLM for the next move; appends both sides to `chat_history`.

        Returns:
            An AgentAction if the LLM requested a function call, else None.
        """
        chat_history.append(input_message)
        messages = self.memory.get_messages() + chat_history
        llm_resp = await self._async_run_llm(
            messages=messages,
            functions=self._tool_manager.get_tool_schemas(),
            system=self.system_message.content if self.system_message is not None else None,
        )
        output_message = llm_resp.message
        chat_history.append(output_message)
        if output_message.function_call is not None:
            return AgentAction(
                tool_name=output_message.function_call["name"],  # type: ignore
                tool_args=output_message.function_call["arguments"],
            )
        else:
            return None

    def _create_finished_response(
        self,
        chat_history: List[Message],
        actions: List[AgentAction],
        files: List[AgentFile],
    ) -> AgentResponse:
        """Build the response for a run that ended with a final answer."""
        last_message = chat_history[-1]
        return AgentResponse(
            text=last_message.content,
            chat_history=chat_history,
            actions=actions,
            files=files,
            status="FINISHED",
        )

    def _create_stopped_response(
        self,
        chat_history: List[Message],
        actions: List[AgentAction],
        files: List[AgentFile],
    ) -> AgentResponse:
        """Build the response for a run that hit the `max_steps` limit."""
        return AgentResponse(
            text="Agent run stopped early.",
            chat_history=chat_history,
            actions=actions,
            files=files,
            status="STOPPED",
        )
erniebot-agent/erniebot_agent/agents/schema.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ from erniebot_agent.file_io.base import File
19
+ from erniebot_agent.messages import AIMessage, Message
20
+ from typing_extensions import Literal
21
+
22
+
23
@dataclass
class AgentAction(object):
    """An action for an agent to execute."""

    tool_name: str  # Name of the tool to invoke.
    tool_args: str  # Tool arguments, serialized as a JSON string.
29
+
30
+
31
@dataclass
class AgentPlan(object):
    """A plan that contains a list of actions."""

    actions: List[AgentAction]  # Actions to execute, in order.
36
+
37
+
38
@dataclass
class LLMResponse(object):
    """A response from an LLM."""

    message: AIMessage  # The model-generated message.
43
+
44
+
45
@dataclass
class ToolResponse(object):
    """A response from a tool."""

    json: str  # Tool output serialized as a JSON string.
    files: List["AgentFile"]  # Files consumed or produced by the tool call.
51
+
52
+
53
@dataclass
class AgentResponse(object):
    """The final response from an agent.

    Attributes:
        text: The agent's final answer (or a stop notice).
        chat_history: All messages exchanged during the run.
        actions: Tool actions taken, in order.
        files: Files consumed or produced by tools during the run.
        status: "FINISHED" for a normal end, "STOPPED" for an early stop.
    """

    text: str
    chat_history: List[Message]
    actions: List[AgentAction]
    files: List["AgentFile"]
    status: Union[Literal["FINISHED"], Literal["STOPPED"]]

    def get_last_output_file(self) -> Optional[File]:
        """Return the most recently produced output file, or None if there is none.

        Rewritten from a fragile `for`/`else` construct: scanning newest-first,
        return the first file of type "output"; fall through to None otherwise.
        """
        for agent_file in reversed(self.files):
            if agent_file.type == "output":
                return agent_file.file
        return None

    def get_output_files(self) -> List[File]:
        """Return all output files, in the order they were produced."""
        return [agent_file.file for agent_file in self.files if agent_file.type == "output"]

    def get_tool_input_output_files(self, tool_name: str) -> Tuple[List[File], List[File]]:
        """Return ``(input_files, output_files)`` used/produced by `tool_name`.

        Raises:
            RuntimeError: If a matching file's type is neither "input" nor "output".
        """
        input_files: List[File] = []
        output_files: List[File] = []
        for agent_file in self.files:
            if agent_file.used_by == tool_name:
                if agent_file.type == "input":
                    input_files.append(agent_file.file)
                elif agent_file.type == "output":
                    output_files.append(agent_file.file)
                else:
                    raise RuntimeError("File type is neither input nor output.")
        return input_files, output_files
85
+
86
+
87
@dataclass
class AgentFile(object):
    """A file that is used by an agent."""

    file: File  # The underlying file object.
    type: Literal["input", "output"]  # Whether the file was consumed or produced.
    used_by: str  # Name of the tool that used the file.
erniebot-agent/erniebot_agent/chat_models/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .erniebot import ERNIEBot
16
+
17
+ __all__ = ["ERNIEBot"]
erniebot-agent/erniebot_agent/chat_models/base.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABCMeta, abstractmethod
16
+ from typing import Any, AsyncIterator, List, Literal, Union, overload
17
+
18
+ from erniebot_agent.messages import AIMessage, AIMessageChunk, Message
19
+
20
+
21
class ChatModel(metaclass=ABCMeta):
    """The base class of chat-optimized LLM."""

    def __init__(self, model: str):
        # Name of the underlying model, e.g. "ernie-bot".
        self.model = model

    # The overloads below only refine the static return type based on `stream`;
    # the single runtime implementation is the abstract method at the end.
    @overload
    async def async_chat(
        self, messages: List[Message], *, stream: Literal[False] = ..., **kwargs: Any
    ) -> AIMessage:
        ...

    @overload
    async def async_chat(
        self, messages: List[Message], *, stream: Literal[True], **kwargs: Any
    ) -> AsyncIterator[AIMessageChunk]:
        ...

    @overload
    async def async_chat(
        self, messages: List[Message], *, stream: bool, **kwargs: Any
    ) -> Union[AIMessage, AsyncIterator[AIMessageChunk]]:
        ...

    @abstractmethod
    async def async_chat(
        self, messages: List[Message], *, stream: bool = False, **kwargs: Any
    ) -> Union[AIMessage, AsyncIterator[AIMessageChunk]]:
        """Asynchronously chats with the LLM.

        Args:
            messages (List[Message]): A list of messages.
            stream (bool): Whether to use streaming generation. Defaults to False.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            If stream is False, returns a single message.
            If stream is True, returns an asynchronous iterator of message chunks.
        """
        raise NotImplementedError
erniebot-agent/erniebot_agent/chat_models/erniebot.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import (
16
+ Any,
17
+ AsyncIterator,
18
+ Dict,
19
+ List,
20
+ Literal,
21
+ Optional,
22
+ Type,
23
+ TypeVar,
24
+ Union,
25
+ overload,
26
+ )
27
+
28
+ from erniebot_agent.chat_models.base import ChatModel
29
+ from erniebot_agent.messages import AIMessage, AIMessageChunk, FunctionCall, Message
30
+
31
+ import erniebot
32
+ from erniebot.response import EBResponse
33
+
34
+ _T = TypeVar("_T", AIMessage, AIMessageChunk)
35
+
36
+
37
class ERNIEBot(ChatModel):
    def __init__(
        self, model: str, api_type: Optional[str] = None, access_token: Optional[str] = None
    ) -> None:
        """Create an `ERNIEBot` chat model.

        Args:
            model (str): The model name. One of "ernie-bot", "ernie-bot-turbo",
                "ernie-bot-8k", or "ernie-bot-4".
            api_type (Optional[str]): Backend API type, "aistudio" or "qianfan".
            access_token (Optional[str]): Access token for the erniebot backend.
        """
        super().__init__(model=model)
        self.api_type = api_type
        self.access_token = access_token

    @overload
    async def async_chat(
        self,
        messages: List[Message],
        *,
        stream: Literal[False] = ...,
        functions: Optional[List[dict]] = ...,
        **kwargs: Any,
    ) -> AIMessage:
        ...

    @overload
    async def async_chat(
        self,
        messages: List[Message],
        *,
        stream: Literal[True],
        functions: Optional[List[dict]] = ...,
        **kwargs: Any,
    ) -> AsyncIterator[AIMessageChunk]:
        ...

    @overload
    async def async_chat(
        self, messages: List[Message], *, stream: bool, functions: Optional[List[dict]] = ..., **kwargs: Any
    ) -> Union[AIMessage, AsyncIterator[AIMessageChunk]]:
        ...

    async def async_chat(
        self,
        messages: List[Message],
        *,
        stream: bool = False,
        functions: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> Union[AIMessage, AsyncIterator[AIMessageChunk]]:
        """Asynchronously chats with the ERNIE Bot model.

        Args:
            messages (List[Message]): A list of messages.
            stream (bool): Whether to use streaming generation. Defaults to False.
            functions (Optional[List[dict]]): Function definitions available to
                the model. Defaults to None.
            **kwargs: Extra options such as `top_p`, `temperature`,
                `penalty_score`, and `system`; unrecognized keys are ignored.

        Returns:
            If `stream` is False, a single message; otherwise an asynchronous
            iterator of message chunks.
        """
        # Authentication bits go under the special `_config_` key.
        auth_config: Dict[str, Any] = {}
        if self.api_type is not None:
            auth_config["api_type"] = self.api_type
        if self.access_token is not None:
            auth_config["access_token"] = self.access_token

        request: Dict[str, Any] = {"model": self.model, "_config_": auth_config}
        # TODO: process system message
        request["messages"] = [message.to_dict() for message in messages]
        if functions is not None:
            request["functions"] = functions

        # Forward only the generation options the backend understands.
        for key in ("top_p", "temperature", "penalty_score", "system"):
            if key in kwargs:
                request[key] = kwargs[key]

        # TODO: Improve this when erniebot typing issue is fixed.
        response: Any = await erniebot.ChatCompletion.acreate(stream=stream, **request)
        if isinstance(response, EBResponse):
            return self.convert_response_to_output(response, AIMessage)
        return (self.convert_response_to_output(chunk, AIMessageChunk) async for chunk in response)

    @staticmethod
    def convert_response_to_output(response: EBResponse, output_type: Type[_T]) -> _T:
        """Convert a raw EBResponse into an AIMessage or AIMessageChunk."""
        if not hasattr(response, "function_call"):
            # Plain text completion.
            return output_type(content=response.result, function_call=None, token_usage=response.usage)
        call = FunctionCall(
            name=response.function_call["name"],
            thoughts=response.function_call["thoughts"],
            arguments=response.function_call["arguments"],
        )
        # Function-call responses carry no text content.
        return output_type(content="", function_call=call, token_usage=response.usage)
erniebot-agent/erniebot_agent/extensions/langchain/chat_models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .erniebot import ErnieBotChat
erniebot-agent/erniebot_agent/extensions/langchain/chat_models/erniebot.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import (
4
+ Any,
5
+ AsyncIterator,
6
+ Callable,
7
+ Dict,
8
+ Iterator,
9
+ List,
10
+ Mapping,
11
+ Optional,
12
+ Type,
13
+ Union,
14
+ )
15
+
16
+ from langchain.callbacks.manager import (
17
+ AsyncCallbackManagerForLLMRun,
18
+ CallbackManagerForLLMRun,
19
+ )
20
+ from langchain.chat_models.base import BaseChatModel
21
+ from langchain.llms.base import create_base_retry_decorator
22
+ from langchain.pydantic_v1 import Field, root_validator
23
+ from langchain.schema import ChatGeneration, ChatResult
24
+ from langchain.schema.messages import (
25
+ AIMessage,
26
+ AIMessageChunk,
27
+ BaseMessage,
28
+ ChatMessage,
29
+ FunctionMessage,
30
+ HumanMessage,
31
+ SystemMessage,
32
+ )
33
+ from langchain.schema.output import ChatGenerationChunk
34
+ from langchain.utils import get_from_dict_or_env
35
+
36
+ _MessageDict = Dict[str, Any]
37
+
38
+
39
class ErnieBotChat(BaseChatModel):
    """ERNIE Bot Chat large language models API.

    To use, you should have the ``erniebot`` python package installed, and the
    environment variable ``EB_ACCESS_TOKEN`` set with your AI Studio access token.

    Example:
        .. code-block:: python
            from erniebot_agent.extensions.langchain.chat_models import ErnieBotChat
            erniebot_chat = ErnieBotChat(model="ernie-bot")
    """

    client: Any = None
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    aistudio_access_token: Optional[str] = None
    """AI Studio access token."""
    streaming: Optional[bool] = False
    """Whether to stream the results or not."""
    model: str = "ernie-bot"
    """Model to use."""
    top_p: Optional[float] = 0.8
    """Parameter of nucleus sampling that affects the diversity of generated content."""
    temperature: Optional[float] = 0.95
    """Sampling temperature to use."""
    penalty_score: Optional[float] = 1
    """Penalty assigned to tokens that have been generated."""
    request_timeout: Optional[int] = 60
    """How many seconds to wait for the server to send data before giving up."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""

    ernie_client_id: Optional[str] = None
    ernie_client_secret: Optional[str] = None
    """For raising deprecation warnings."""

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling ERNIE Bot API."""
        normal_params = {
            "model": self.model,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "penalty_score": self.penalty_score,
            "request_timeout": self.request_timeout,
        }
        return {**normal_params, **self.model_kwargs}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Parameters identifying this model configuration."""
        return self._default_params

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        auth_cfg: Dict[str, Optional[str]] = {
            "api_type": "aistudio",
            "access_token": self.aistudio_access_token,
        }
        return {**{"_config_": auth_cfg}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "erniebot"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the access token and bind the erniebot client.

        Note: renamed from the misspelled ``validate_enviroment``; the
        ``@root_validator`` registration does not depend on the method name.
        """
        values["aistudio_access_token"] = get_from_dict_or_env(
            values,
            "aistudio_access_token",
            "EB_ACCESS_TOKEN",
        )

        try:
            import erniebot

            values["client"] = erniebot.ChatCompletion
        except ImportError:
            raise ImportError(
                "Could not import erniebot python package. Please install it with `pip install erniebot`."
            )
        return values

    def _prepare_params(
        self, messages: List[BaseMessage], *, stream: bool, **kwargs: Any
    ) -> Dict[str, Any]:
        """Build the keyword arguments for one erniebot chat-completion call.

        Merges the invocation parameters with call-site overrides, converts
        `messages` to erniebot dicts, and folds any system messages into the
        `system` parameter. Extracted to remove the duplication previously
        spread across `_generate`, `_agenerate`, `_stream`, and `_astream`.
        """
        params = self._invocation_params
        params.update(kwargs)
        params["messages"] = self._convert_messages_to_dicts(messages)
        system_prompt = self._build_system_prompt_from_messages(messages)
        if system_prompt is not None:
            params["system"] = system_prompt
        params["stream"] = stream
        return params

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result; aggregates streamed chunks when streaming."""
        if self.streaming:
            chunks = self._stream(messages, stop=stop, run_manager=run_manager, **kwargs)
            generation: Optional[ChatGenerationChunk] = None
            for chunk in chunks:
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            assert generation is not None
            return ChatResult(generations=[generation])
        else:
            params = self._prepare_params(messages, stream=False, **kwargs)
            response = _create_completion_with_retry(self, run_manager=run_manager, **params)
            return self._build_chat_result_from_response(response)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronously generate a chat result; aggregates chunks when streaming."""
        if self.streaming:
            chunks = self._astream(messages, stop=stop, run_manager=run_manager, **kwargs)
            generation: Optional[ChatGenerationChunk] = None
            async for chunk in chunks:
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            assert generation is not None
            return ChatResult(generations=[generation])
        else:
            params = self._prepare_params(messages, stream=False, **kwargs)
            response = await _acreate_completion_with_retry(self, run_manager=run_manager, **params)
            return self._build_chat_result_from_response(response)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield generation chunks from a streaming completion call.

        Raises:
            TypeError: If `stop` is given; not supported while streaming.
        """
        if stop is not None:
            raise TypeError("Currently, `stop` is not supported when streaming is enabled.")
        params = self._prepare_params(messages, stream=True, **kwargs)
        for resp in _create_completion_with_retry(self, run_manager=run_manager, **params):
            chunk = self._build_chunk_from_response(resp)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Asynchronously yield generation chunks from a streaming call.

        Raises:
            TypeError: If `stop` is given; not supported while streaming.
        """
        if stop is not None:
            raise TypeError("Currently, `stop` is not supported when streaming is enabled.")
        params = self._prepare_params(messages, stream=True, **kwargs)
        async for resp in await _acreate_completion_with_retry(self, run_manager=run_manager, **params):
            chunk = self._build_chunk_from_response(resp)
            yield chunk
            if run_manager:
                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)

    def _build_chat_result_from_response(self, response: Mapping[str, Any]) -> ChatResult:
        """Wrap a raw erniebot response into a LangChain ChatResult."""
        message_dict = self._build_dict_from_response(response)
        generation = ChatGeneration(
            message=self._convert_dict_to_message(message_dict),
            generation_info=dict(finish_reason="stop"),
        )
        token_usage = response.get("usage", {})
        llm_output = {"token_usage": token_usage, "model_name": self.model}
        return ChatResult(generations=[generation], llm_output=llm_output)

    def _build_chunk_from_response(self, response: Mapping[str, Any]) -> ChatGenerationChunk:
        """Wrap a raw streamed erniebot response into a ChatGenerationChunk."""
        message_dict = self._build_dict_from_response(response)
        message = self._convert_dict_to_message(message_dict)
        msg_chunk = AIMessageChunk(
            content=message.content,
            additional_kwargs=message.additional_kwargs,
        )
        return ChatGenerationChunk(message=msg_chunk)

    def _build_dict_from_response(self, response: Mapping[str, Any]) -> _MessageDict:
        """Extract an assistant message dict (text or function call) from a response."""
        message_dict: _MessageDict = {"role": "assistant"}
        if "function_call" in response:
            # Function-call replies carry no text content.
            message_dict["content"] = None
            message_dict["function_call"] = response["function_call"]
        else:
            message_dict["content"] = response["result"]
        return message_dict

    def _build_system_prompt_from_messages(self, messages: List[BaseMessage]) -> Optional[str]:
        """Join all SystemMessage contents into one system prompt, or None.

        Raises:
            TypeError: If a system message's content is not a plain string.
        """
        system_message_content_list: List[str] = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                if isinstance(msg.content, str):
                    system_message_content_list.append(msg.content)
                else:
                    raise TypeError
        if len(system_message_content_list) > 0:
            return "\n".join(system_message_content_list)
        else:
            return None

    def _convert_messages_to_dicts(self, messages: List[BaseMessage]) -> List[dict]:
        """Convert LangChain messages to erniebot dicts, dropping system messages."""
        erniebot_messages = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                # Ignore system messages, as we handle them elsewhere.
                continue
            eb_msg = self._convert_message_to_dict(msg)
            erniebot_messages.append(eb_msg)
        return erniebot_messages

    @staticmethod
    def _convert_dict_to_message(message_dict: _MessageDict) -> BaseMessage:
        """Convert an erniebot message dict to the matching LangChain message."""
        role = message_dict["role"]
        if role == "user":
            return HumanMessage(content=message_dict["content"])
        elif role == "assistant":
            content = message_dict["content"] or ""
            if message_dict.get("function_call"):
                additional_kwargs = {"function_call": dict(message_dict["function_call"])}
            else:
                additional_kwargs = {}
            return AIMessage(content=content, additional_kwargs=additional_kwargs)
        elif role == "function":
            return FunctionMessage(content=message_dict["content"], name=message_dict["name"])
        else:
            return ChatMessage(content=message_dict["content"], role=role)

    @staticmethod
    def _convert_message_to_dict(message: BaseMessage) -> _MessageDict:
        """Convert a LangChain message to an erniebot message dict.

        Raises:
            TypeError: If the message type is not supported.
        """
        message_dict: _MessageDict
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
            if "function_call" in message.additional_kwargs:
                message_dict["function_call"] = message.additional_kwargs["function_call"]
                # The erniebot API expects null content alongside a function call.
                if message_dict["content"] == "":
                    message_dict["content"] = None
        elif isinstance(message, FunctionMessage):
            message_dict = {
                "role": "function",
                "content": message.content,
                "name": message.name,
            }
        else:
            raise TypeError(f"Got unknown type {message}")

        return message_dict
314
+
315
+
316
def _create_completion_with_retry(
    llm: ErnieBotChat,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Call ``llm.client.create`` wrapped in the retry policy of ``llm``."""

    @_create_retry_decorator(llm, run_manager=run_manager)
    def _do_create(**call_kwargs: Any) -> Any:
        return llm.client.create(**call_kwargs)

    return _do_create(**kwargs)
328
+
329
+
330
async def _acreate_completion_with_retry(
    llm: ErnieBotChat,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Asynchronously call ``llm.client.acreate`` wrapped in the retry policy."""

    @_create_retry_decorator(llm, run_manager=run_manager)
    async def _do_acreate(**call_kwargs: Any) -> Any:
        return await llm.client.acreate(**call_kwargs)

    return await _do_acreate(**kwargs)
342
+
343
+
344
def _create_retry_decorator(
    llm: ErnieBotChat,
    run_manager: Optional[Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]] = None,
) -> Callable[[Any], Any]:
    """Build a retry decorator for transient erniebot errors (timeouts, rate limits)."""
    import erniebot

    retryable: List[Type[BaseException]] = [
        erniebot.errors.TimeoutError,
        erniebot.errors.RequestLimitError,
    ]
    return create_base_retry_decorator(
        error_types=retryable,
        max_retries=llm.max_retries,
        run_manager=run_manager,
    )
erniebot-agent/erniebot_agent/extensions/langchain/embeddings/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .ernie import ErnieEmbeddings
erniebot-agent/erniebot_agent/extensions/langchain/embeddings/ernie.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ from langchain.pydantic_v1 import BaseModel, root_validator
6
+ from langchain.schema.embeddings import Embeddings
7
+ from langchain.utils import get_from_dict_or_env
8
+
9
+
10
class ErnieEmbeddings(BaseModel, Embeddings):
    """ERNIE embedding models.

    To use, you should have the ``erniebot`` python package installed, and the
    environment variable ``EB_ACCESS_TOKEN`` set with your AI Studio access token.

    Example:
        .. code-block:: python
            from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
            ernie_embeddings = ErnieEmbeddings()
    """

    client: Any = None
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    chunk_size: int = 16
    """Chunk size to use when the input is a list of texts."""
    aistudio_access_token: Optional[str] = None
    """AI Studio access token."""
    model: str = "ernie-text-embedding"
    """Model to use."""
    request_timeout: Optional[int] = 60
    """How many seconds to wait for the server to send data before giving up."""

    ernie_client_id: Optional[str] = None
    ernie_client_secret: Optional[str] = None
    """For raising deprecation warnings."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the access token and bind the erniebot Embedding client."""
        values["aistudio_access_token"] = get_from_dict_or_env(
            values,
            "aistudio_access_token",
            "EB_ACCESS_TOKEN",
        )

        try:
            import erniebot

            values["client"] = erniebot.Embedding
        except ImportError:
            raise ImportError(
                "Could not import erniebot python package. Please install it with `pip install erniebot`."
            )
        return values

    def embed_query(self, text: str) -> List[float]:
        """Embed a single piece of query text."""
        resp = self.embed_documents([text])
        return resp[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed a single piece of query text."""
        embeddings = await self.aembed_documents([text])
        return embeddings[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents, issuing one request per `chunk_size` batch of texts."""
        text_in_chunks = [texts[i : i + self.chunk_size] for i in range(0, len(texts), self.chunk_size)]
        lst: List[List[float]] = []
        for chunk in text_in_chunks:
            resp = self.client.create(_config_=self._get_auth_config(), input=chunk, model=self.model)
            lst.extend([res["embedding"] for res in resp["data"]])
        return lst

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed documents, batching in `chunk_size` groups.

        Result collection is now identical to the sync variant (previously an
        awkward per-item ``extend([...])`` loop).
        """
        text_in_chunks = [texts[i : i + self.chunk_size] for i in range(0, len(texts), self.chunk_size)]
        lst: List[List[float]] = []
        for chunk in text_in_chunks:
            resp = await self.client.acreate(_config_=self._get_auth_config(), input=chunk, model=self.model)
            lst.extend([res["embedding"] for res in resp["data"]])
        return lst

    def _get_auth_config(self) -> dict:
        """Auth config dict passed as ``_config_`` to every erniebot call."""
        return {"api_type": "aistudio", "access_token": self.aistudio_access_token}
erniebot-agent/erniebot_agent/extensions/langchain/llms/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .erniebot import ErnieBot
erniebot-agent/erniebot_agent/extensions/langchain/llms/erniebot.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import (
4
+ Any,
5
+ AsyncIterator,
6
+ Callable,
7
+ Dict,
8
+ Iterator,
9
+ List,
10
+ Mapping,
11
+ Optional,
12
+ Type,
13
+ Union,
14
+ )
15
+
16
+ from langchain.callbacks.manager import (
17
+ AsyncCallbackManagerForLLMRun,
18
+ CallbackManagerForLLMRun,
19
+ )
20
+ from langchain.llms.base import LLM, create_base_retry_decorator
21
+ from langchain.llms.utils import enforce_stop_tokens
22
+ from langchain.pydantic_v1 import Field, root_validator
23
+ from langchain.schema.output import GenerationChunk
24
+ from langchain.utils import get_from_dict_or_env
25
+
26
+
27
class ErnieBot(LLM):
    """ERNIE Bot large language models.

    To use, you should have the ``erniebot`` python package installed, and the
    environment variable ``EB_ACCESS_TOKEN`` set with your AI Studio access token.

    Example:
        .. code-block:: python

            from erniebot_agent.extensions.langchain.llms import ErnieBot
            erniebot = ErnieBot(model="ernie-bot")
    """

    client: Any = None  # Bound to `erniebot.ChatCompletion` by the validator below.
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    aistudio_access_token: Optional[str] = None
    """AI Studio access token."""
    streaming: Optional[bool] = False
    """Whether to stream the results or not."""
    model: str = "ernie-bot"
    """Model to use."""
    top_p: Optional[float] = 0.8
    """Parameter of nucleus sampling that affects the diversity of generated content."""
    temperature: Optional[float] = 0.95
    """Sampling temperature to use."""
    penalty_score: Optional[float] = 1
    """Penalty assigned to tokens that have been generated."""
    request_timeout: Optional[int] = 60
    """How many seconds to wait for the server to send data before giving up."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling ERNIE Bot API."""
        normal_params = {
            "model": self.model,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "penalty_score": self.penalty_score,
            "request_timeout": self.request_timeout,
        }
        # On key collision, entries from `model_kwargs` override the explicit ones.
        return {**normal_params, **self.model_kwargs}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Parameters identifying this LLM configuration (used by LangChain)."""
        return self._default_params

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        auth_cfg: Dict[str, Optional[str]] = {
            "api_type": "aistudio",
            "access_token": self.aistudio_access_token,
        }
        return {**{"_config_": auth_cfg}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "erniebot"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the AI Studio access token and bind the erniebot client.

        NOTE: renamed from the misspelled ``validate_enviroment`` for
        consistency with the embeddings wrapper; the validator is registered
        by the decorator, so the method name is not part of the public API.

        Raises:
            ImportError: If the ``erniebot`` package is not installed.
        """
        values["aistudio_access_token"] = get_from_dict_or_env(
            values,
            "aistudio_access_token",
            "EB_ACCESS_TOKEN",
        )

        try:
            import erniebot

            values["client"] = erniebot.ChatCompletion
        except ImportError:
            raise ImportError(
                "Could not import erniebot python package. Please install it with `pip install erniebot`."
            )
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt``.

        In streaming mode the chunks are concatenated into one string;
        otherwise a single blocking API call is made and ``stop`` tokens are
        enforced on the result.
        """
        if self.streaming:
            text = ""
            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
                text += chunk.text
            return text
        else:
            params = self._invocation_params
            params.update(kwargs)
            params["messages"] = [self._build_user_message_from_prompt(prompt)]
            params["stream"] = False
            response = _create_completion_with_retry(self, run_manager=run_manager, **params)
            text = response["result"]
            if stop is not None:
                text = enforce_stop_tokens(text, stop)
            return text

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Asynchronous counterpart of ``_call``."""
        if self.streaming:
            text = ""
            async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
                text += chunk.text
            return text
        else:
            params = self._invocation_params
            params.update(kwargs)
            params["messages"] = [self._build_user_message_from_prompt(prompt)]
            params["stream"] = False
            response = await _acreate_completion_with_retry(self, run_manager=run_manager, **params)
            text = response["result"]
            if stop is not None:
                text = enforce_stop_tokens(text, stop)
            return text

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream generation chunks for ``prompt``.

        Raises:
            TypeError: If ``stop`` is given (not supported while streaming).
        """
        if stop is not None:
            raise TypeError("Currently, `stop` is not supported when streaming is enabled.")
        params = self._invocation_params
        params.update(kwargs)
        params["messages"] = [self._build_user_message_from_prompt(prompt)]
        params["stream"] = True
        for resp in _create_completion_with_retry(self, run_manager=run_manager, **params):
            chunk = self._build_chunk_from_response(resp)
            yield chunk
            # Notify callbacks after yielding so consumers see the chunk first.
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        """Asynchronous counterpart of ``_stream``.

        Raises:
            TypeError: If ``stop`` is given (not supported while streaming).
        """
        if stop is not None:
            raise TypeError("Currently, `stop` is not supported when streaming is enabled.")
        params = self._invocation_params
        params.update(kwargs)
        params["messages"] = [self._build_user_message_from_prompt(prompt)]
        params["stream"] = True
        async for resp in await _acreate_completion_with_retry(self, run_manager=run_manager, **params):
            chunk = self._build_chunk_from_response(resp)
            yield chunk
            if run_manager:
                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)

    def _build_chunk_from_response(self, response: Mapping[str, Any]) -> GenerationChunk:
        """Wrap one API response payload into a LangChain ``GenerationChunk``."""
        return GenerationChunk(text=response["result"])

    def _build_user_message_from_prompt(self, prompt: str) -> Dict[str, str]:
        """Wrap a plain prompt string into a user chat message."""
        return {"role": "user", "content": prompt}
197
+
198
+
199
def _create_completion_with_retry(
    llm: ErnieBot,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Invoke ``llm.client.create`` with retries on transient API errors."""

    @_create_retry_decorator(llm, run_manager=run_manager)
    def _invoke(**call_kwargs: Any) -> Any:
        return llm.client.create(**call_kwargs)

    return _invoke(**kwargs)
211
+
212
+
213
async def _acreate_completion_with_retry(
    llm: ErnieBot,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Asynchronously invoke ``llm.client.acreate`` with retries on transient API errors."""

    @_create_retry_decorator(llm, run_manager=run_manager)
    async def _ainvoke(**call_kwargs: Any) -> Any:
        return await llm.client.acreate(**call_kwargs)

    return await _ainvoke(**kwargs)
225
+
226
+
227
def _create_retry_decorator(
    llm: ErnieBot,
    run_manager: Optional[Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]] = None,
) -> Callable[[Any], Any]:
    """Build a retry decorator for transient ERNIE Bot API failures.

    Retries only on timeouts and rate-limit errors, up to ``llm.max_retries``
    attempts, using LangChain's base retry machinery.
    """
    # Imported lazily so this module stays importable without `erniebot`.
    import erniebot

    errors: List[Type[BaseException]] = [
        erniebot.errors.TimeoutError,
        erniebot.errors.RequestLimitError,
    ]
    return create_base_retry_decorator(
        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
    )
erniebot-agent/erniebot_agent/file_io/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
erniebot-agent/erniebot_agent/file_io/base.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import abc
16
+
17
+
18
class File(metaclass=abc.ABCMeta):
    """Abstract base class for files tracked by the agent.

    Two ``File`` objects compare equal iff their IDs are equal; comparing
    against a non-``File`` object yields ``False``.
    """

    def __init__(self, id: str, filename: str, created_at: int) -> None:
        super().__init__()
        self.id = id
        self.filename = filename
        self.created_at = created_at

    def __eq__(self, other: object) -> bool:
        if isinstance(other, File):
            return self.id == other.id
        else:
            return False

    def __hash__(self) -> int:
        # Fix: defining __eq__ implicitly sets __hash__ to None, making
        # instances unhashable. Hash on the same key used for equality so
        # files can live in sets / as dict keys.
        return hash(self.id)

    def __repr__(self) -> str:
        attrs_str = self._get_attrs_str()
        return f"<{self.__class__.__name__} {attrs_str}>"

    @abc.abstractmethod
    async def read_contents(self) -> bytes:
        """Return the file's contents as bytes (implemented by subclasses)."""
        raise NotImplementedError

    def _get_attrs_str(self) -> str:
        """Comma-separated attribute summary used by ``__repr__``."""
        return ", ".join(
            [
                f"id: {repr(self.id)}",
                f"filename: {repr(self.filename)}",
                f"created_at: {repr(self.created_at)}",
            ]
        )
erniebot-agent/erniebot_agent/file_io/file_manager.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import pathlib
17
+ import uuid
18
+ from typing import Literal, Optional, Union, overload
19
+
20
+ import anyio
21
+ from erniebot_agent.file_io.base import File
22
+ from erniebot_agent.file_io.file_registry import FileRegistry, get_file_registry
23
+ from erniebot_agent.file_io.local_file import LocalFile, create_local_file_from_path
24
+ from erniebot_agent.file_io.remote_file import RemoteFile, RemoteFileClient
25
+ from erniebot_agent.utils.temp_file import create_tracked_temp_dir
26
+ from typing_extensions import TypeAlias
27
+
28
+ _PathType: TypeAlias = Union[str, os.PathLike]
29
+
30
+
31
class FileManager(object):
    """Creates and tracks local and remote files.

    Local files are materialized under ``save_dir`` (a tracked temporary
    directory when not given); remote operations are delegated to
    ``remote_file_client``. Created files are recorded in the global file
    registry.
    """

    # None when no remote backend was configured; the `remote_file_client`
    # property then raises AttributeError.
    _remote_file_client: Optional[RemoteFileClient]

    def __init__(
        self,
        remote_file_client: Optional[RemoteFileClient] = None,
        *,
        auto_register: bool = True,
        save_dir: Optional[_PathType] = None,
    ) -> None:
        super().__init__()
        if remote_file_client is not None:
            self._remote_file_client = remote_file_client
        else:
            self._remote_file_client = None
        self._auto_register = auto_register
        if save_dir is not None:
            self._save_dir = pathlib.Path(save_dir)
        else:
            # This can be done lazily, but we need to be careful about race conditions.
            self._save_dir = create_tracked_temp_dir()

        self._file_registry = get_file_registry()

    @property
    def registry(self) -> FileRegistry:
        """The registry that created files are recorded in."""
        return self._file_registry

    @property
    def remote_file_client(self) -> RemoteFileClient:
        """The configured remote file client.

        Raises:
            AttributeError: If no remote file client was set.
        """
        if self._remote_file_client is None:
            raise AttributeError("No remote file client is set.")
        else:
            return self._remote_file_client

    @overload
    async def create_file_from_path(
        self, file_path: _PathType, *, file_type: Literal["local"] = ...
    ) -> LocalFile:
        ...

    @overload
    async def create_file_from_path(
        self, file_path: _PathType, *, file_type: Literal["remote"]
    ) -> RemoteFile:
        ...

    async def create_file_from_path(
        self, file_path: _PathType, *, file_type: Literal["local", "remote"] = "local"
    ) -> Union[LocalFile, RemoteFile]:
        """Create a local or remote file from an existing path.

        Raises:
            ValueError: If ``file_type`` is neither "local" nor "remote".
        """
        file: Union[LocalFile, RemoteFile]
        if file_type == "local":
            file = await self.create_local_file_from_path(file_path)
        elif file_type == "remote":
            file = await self.create_remote_file_from_path(file_path)
        else:
            raise ValueError(f"Unsupported file type: {file_type}")
        return file

    async def create_local_file_from_path(self, file_path: _PathType) -> LocalFile:
        """Wrap ``file_path`` as a ``LocalFile`` and register it.

        NOTE(review): registration here is unconditional — unlike the remote
        variants it does not consult ``auto_register``; confirm intended.
        """
        file = create_local_file_from_path(pathlib.Path(file_path))
        self._file_registry.register_file(file)
        return file

    async def create_remote_file_from_path(self, file_path: _PathType) -> RemoteFile:
        """Upload ``file_path`` via the remote client; register it if enabled."""
        file = await self.remote_file_client.upload_file(pathlib.Path(file_path))
        if self._auto_register:
            self._file_registry.register_file(file)
        return file

    @overload
    async def create_file_from_bytes(
        self, file_contents: bytes, filename: str, *, file_type: Literal["local"] = ...
    ) -> LocalFile:
        ...

    @overload
    async def create_file_from_bytes(
        self, file_contents: bytes, filename: str, *, file_type: Literal["remote"]
    ) -> RemoteFile:
        ...

    async def create_file_from_bytes(
        self, file_contents: bytes, filename: str, *, file_type: Literal["local", "remote"] = "local"
    ) -> Union[LocalFile, RemoteFile]:
        """Persist ``file_contents`` under ``save_dir`` and create a file from it.

        The on-disk name keeps the given filename's stem and suffix around a
        random UUID to avoid collisions.
        """
        # Can we do this with in-memory files?
        file_path = self._fs_create_file(
            prefix=pathlib.PurePath(filename).stem, suffix=pathlib.PurePath(filename).suffix
        )
        async with await anyio.open_file(file_path, "wb") as f:
            await f.write(file_contents)
        file = await self.create_file_from_path(file_path, file_type=file_type)
        return file

    async def retrieve_remote_file_by_id(self, file_id: str) -> RemoteFile:
        """Fetch a remote file handle by ID; register it if enabled."""
        file = await self.remote_file_client.retrieve_file(file_id)
        if self._auto_register:
            self._file_registry.register_file(file)
        return file

    def look_up_file_by_id(self, file_id: str) -> Optional[File]:
        """Return the registered file with ``file_id``, or None if unknown."""
        return self._file_registry.look_up_file(file_id)

    def _fs_create_file(self, prefix: Optional[str] = None, suffix: Optional[str] = None) -> pathlib.Path:
        """Create (touch) a uniquely named empty file under ``save_dir``."""
        filename = f"{prefix or ''}{str(uuid.uuid4())}{suffix or ''}"
        file_path = self._save_dir / filename
        file_path.touch()
        return file_path
erniebot-agent/erniebot_agent/file_io/file_registry.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import threading
16
+ from typing import Dict, List, Optional
17
+
18
+ from erniebot_agent.file_io.base import File
19
+ from erniebot_agent.utils.misc import Singleton
20
+
21
+
22
class FileRegistry(metaclass=Singleton):
    """Process-wide, thread-safe mapping from file IDs to ``File`` objects."""

    def __init__(self) -> None:
        super().__init__()
        self._id_to_file: Dict[str, File] = {}
        self._lock = threading.Lock()

    def register_file(self, file: File) -> None:
        """Register ``file`` under its ID, replacing any existing entry."""
        with self._lock:
            # Re-registering an existing file is allowed; the entry is
            # simply overwritten.
            self._id_to_file[file.id] = file

    def unregister_file(self, file: File) -> None:
        """Remove ``file`` from the registry.

        Raises:
            RuntimeError: If the file's ID is not registered.
        """
        with self._lock:
            if file.id not in self._id_to_file:
                raise RuntimeError(f"ID {repr(file.id)} is not registered.")
            del self._id_to_file[file.id]

    def look_up_file(self, file_id: str) -> Optional[File]:
        """Return the file registered under ``file_id``, or None."""
        with self._lock:
            return self._id_to_file.get(file_id)

    def list_files(self) -> List[File]:
        """Return a snapshot list of all registered files."""
        with self._lock:
            return list(self._id_to_file.values())
49
+
50
+
51
# Module-level instance; FileRegistry uses a Singleton metaclass, so this is
# the single shared registry for the process.
_file_registry = FileRegistry()


def get_file_registry() -> FileRegistry:
    """Return the process-wide ``FileRegistry`` instance."""
    return _file_registry
erniebot-agent/erniebot_agent/file_io/local_file.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import pathlib
16
+ import time
17
+ import uuid
18
+
19
+ import anyio
20
+ from erniebot_agent.file_io.base import File
21
+ from erniebot_agent.file_io.protocol import (
22
+ build_local_file_id_from_uuid,
23
+ is_local_file_id,
24
+ )
25
+ from erniebot_agent.utils.logging import logger
26
+
27
+
28
class LocalFile(File):
    """A file backed by a path on the local filesystem."""

    def __init__(self, id: str, filename: str, created_at: int, path: pathlib.Path) -> None:
        if not is_local_file_id(id):
            # Fix: the message was missing the f-string prefix and rendered
            # the literal text "{id}".
            raise ValueError(f"Invalid file ID: {id}")
        super().__init__(id=id, filename=filename, created_at=created_at)
        self.path = path

    async def read_contents(self) -> bytes:
        """Read and return the file's contents as bytes."""
        return await anyio.Path(self.path).read_bytes()

    def _get_attrs_str(self) -> str:
        """Extend the base attribute summary with the local path."""
        attrs_str = super()._get_attrs_str()
        attrs_str += f", path: {repr(self.path)}"
        return attrs_str
42
+
43
+
44
def create_local_file_from_path(file_path: pathlib.Path) -> LocalFile:
    """Create a ``LocalFile`` describing ``file_path``.

    A warning is logged (not raised) when the path does not exist, so callers
    may pre-create handles for files written later.
    """
    if not file_path.exists():
        # Fix: `Logger.warn` is deprecated; use `Logger.warning`.
        logger.warning("File %s does not exist.", file_path)
    file_id = _generate_local_file_id()
    filename = file_path.name
    created_at = int(time.time())
    file = LocalFile(id=file_id, filename=filename, created_at=created_at, path=file_path)
    return file
52
+
53
+
54
def _generate_local_file_id() -> str:
    """Generate a fresh local-file ID in the registry's canonical format."""
    # NOTE(review): uuid1 embeds the host MAC address and a timestamp;
    # consider uuid4 if IDs must be non-guessable — confirm nothing relies
    # on time-ordering of IDs.
    return build_local_file_id_from_uuid(str(uuid.uuid1()))
erniebot-agent/erniebot_agent/file_io/protocol.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from typing import List
17
+
18
# File-ID naming scheme shared by local and remote files:
# "<prefix><uuid>" where the UUID is the canonical 8-4-4-4-12 lowercase
# hex form.
_LOCAL_FILE_ID_PREFIX = "file-local-"
_REMOTE_FILE_ID_PREFIX = "file-remote-"
_UUID_PATTERN = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
_LOCAL_FILE_ID_PATTERN = _LOCAL_FILE_ID_PREFIX + _UUID_PATTERN
_REMOTE_FILE_ID_PATTERN = _REMOTE_FILE_ID_PREFIX + _UUID_PATTERN

# Compiled once at import time and reused by the predicates/extractors below.
_compiled_local_file_id_pattern = re.compile(_LOCAL_FILE_ID_PATTERN)
_compiled_remote_file_id_pattern = re.compile(_REMOTE_FILE_ID_PATTERN)


def build_local_file_id_from_uuid(uuid: str) -> str:
    """Return the local-file ID for ``uuid`` (no validation is performed)."""
    return _LOCAL_FILE_ID_PREFIX + uuid


def build_remote_file_id_from_uuid(uuid: str) -> str:
    """Return the remote-file ID for ``uuid`` (no validation is performed)."""
    return _REMOTE_FILE_ID_PREFIX + uuid


def is_file_id(str_: str) -> bool:
    """True if ``str_`` is exactly a local or remote file ID."""
    return is_local_file_id(str_) or is_remote_file_id(str_)


def is_local_file_id(str_: str) -> bool:
    """True if ``str_`` is exactly a local file ID (full match)."""
    return _compiled_local_file_id_pattern.fullmatch(str_) is not None


def is_remote_file_id(str_: str) -> bool:
    """True if ``str_`` is exactly a remote file ID (full match)."""
    return _compiled_remote_file_id_pattern.fullmatch(str_) is not None


def extract_file_ids(str_: str) -> List[str]:
    """All file IDs embedded in ``str_`` (local IDs first, then remote)."""
    return extract_local_file_ids(str_) + extract_remote_file_ids(str_)


def extract_local_file_ids(str_: str) -> List[str]:
    """All local file IDs embedded anywhere in ``str_``."""
    return _compiled_local_file_id_pattern.findall(str_)


def extract_remote_file_ids(str_: str) -> List[str]:
    """All remote file IDs embedded anywhere in ``str_``."""
    return _compiled_remote_file_id_pattern.findall(str_)
erniebot-agent/erniebot_agent/file_io/remote_file.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import abc
16
+ import asyncio
17
+ import functools
18
+ import pathlib
19
+ import time
20
+ import uuid
21
+ from typing import ClassVar, Dict, List
22
+
23
+ import anyio
24
+ from baidubce.auth.bce_credentials import BceCredentials
25
+ from baidubce.bce_client_configuration import BceClientConfiguration
26
+ from baidubce.services.bos.bos_client import BosClient
27
+ from erniebot_agent.file_io.base import File
28
+ from erniebot_agent.file_io.protocol import (
29
+ build_remote_file_id_from_uuid,
30
+ is_remote_file_id,
31
+ )
32
+
33
+
34
class RemoteFile(File):
    """A file stored in a remote backend, accessed through a ``RemoteFileClient``."""

    def __init__(self, id: str, filename: str, created_at: int, client: "RemoteFileClient") -> None:
        if not is_remote_file_id(id):
            # Fix: the message was missing the f-string prefix and rendered
            # the literal text "{id}".
            raise ValueError(f"Invalid file ID: {id}")
        super().__init__(id=id, filename=filename, created_at=created_at)
        self._client = client

    async def read_contents(self) -> bytes:
        """Download and return the file's contents as bytes."""
        file_contents = await self._client.retrieve_file_contents(self.id)
        return file_contents

    async def delete(self) -> None:
        """Delete the file from the remote backend."""
        await self._client.delete_file(self.id)
47
+
48
+
49
class RemoteFileClient(metaclass=abc.ABCMeta):
    """Abstract interface to a remote file-storage backend."""

    @abc.abstractmethod
    async def upload_file(self, file_path: pathlib.Path) -> RemoteFile:
        """Upload the file at ``file_path`` and return its remote handle."""
        raise NotImplementedError

    @abc.abstractmethod
    async def retrieve_file(self, file_id: str) -> RemoteFile:
        """Return a ``RemoteFile`` handle for an existing remote file."""
        raise NotImplementedError

    @abc.abstractmethod
    async def retrieve_file_contents(self, file_id: str) -> bytes:
        """Download and return the raw contents of the remote file."""
        raise NotImplementedError

    @abc.abstractmethod
    async def list_files(self) -> List[RemoteFile]:
        """Return handles for all files in the backend (if supported)."""
        raise NotImplementedError

    @abc.abstractmethod
    async def delete_file(self, file_id: str) -> None:
        """Delete the remote file."""
        raise NotImplementedError
69
+
70
+
71
class BOSFileClient(RemoteFileClient):
    """Remote file client backed by Baidu Object Storage (BOS).

    Blocking BCE SDK calls are dispatched to the default thread-pool
    executor so the async interface does not block the event loop.
    """

    # Fixed Beijing-region endpoint.
    _ENDPOINT: ClassVar[str] = "bj.bcebos.com"

    def __init__(self, ak: str, sk: str, bucket_name: str, prefix: str) -> None:
        """Create a client for ``bucket_name`` with AK/SK credentials.

        ``prefix`` is prepended to file IDs to form object keys.
        """
        super().__init__()
        self.bucket_name = bucket_name
        self.prefix = prefix
        config = BceClientConfiguration(credentials=BceCredentials(ak, sk), endpoint=self._ENDPOINT)
        self._bos_client = BosClient(config=config)

    async def upload_file(self, file_path: pathlib.Path) -> RemoteFile:
        """Upload ``file_path`` to BOS and return the resulting ``RemoteFile``.

        The file's id/filename/creation-time are stored as BOS user metadata
        alongside the object so ``retrieve_file`` can reconstruct the handle.
        """
        file_id = self._generate_file_id()
        filename = file_path.name
        created_at = int(time.time())
        user_metadata: Dict[str, str] = {"id": file_id, "filename": filename, "created_at": str(created_at)}
        async with await anyio.open_file(file_path, mode="rb") as f:
            data = await f.read()
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            functools.partial(
                self._bos_client.put_object_from_string,
                bucket=self.bucket_name,
                key=self._get_key(file_id),
                data=data,
                user_metadata=user_metadata,
            ),
        )
        return RemoteFile(
            id=file_id,
            filename=filename,
            created_at=created_at,
            client=self,
        )

    async def retrieve_file(self, file_id: str) -> RemoteFile:
        """Rebuild a ``RemoteFile`` from the object's BOS user metadata.

        Raises:
            RuntimeError: If the ID stored in metadata does not match
                ``file_id``.
        """
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            functools.partial(
                self._bos_client.get_object_meta_data, self.bucket_name, self._get_key(file_id)
            ),
        )
        user_metadata = {
            "id": response.metadata.bce_meta_id,
            "filename": response.metadata.bce_meta_filename,
            "created_at": int(response.metadata.bce_meta_created_at),
        }
        if file_id != user_metadata["id"]:
            raise RuntimeError("`file_id` is not the same as the one in metadata.")

        return RemoteFile(
            id=user_metadata["id"],
            filename=user_metadata["filename"],
            created_at=user_metadata["created_at"],
            client=self,
        )

    async def retrieve_file_contents(self, file_id: str) -> bytes:
        """Download and return the raw contents of the object for ``file_id``."""
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            functools.partial(
                self._bos_client.get_object_as_string, self.bucket_name, self._get_key(file_id)
            ),
        )
        return result

    async def list_files(self) -> List[RemoteFile]:
        """Not supported by this backend; always raises ``RuntimeError``."""
        raise RuntimeError(f"`{self.__class__.__name__}.list_files` is not supported.")

    async def delete_file(self, file_id: str) -> None:
        """Delete the object corresponding to ``file_id`` from the bucket."""
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None, functools.partial(self._bos_client.delete_object, self.bucket_name, self._get_key(file_id))
        )

    def _get_key(self, file_id: str) -> str:
        """Object key for ``file_id``: configured prefix + file ID."""
        return self.prefix + file_id

    @staticmethod
    def _generate_file_id() -> str:
        """Generate a remote-file ID in the registry's canonical format."""
        return build_remote_file_id_from_uuid(str(uuid.uuid1()))
erniebot-agent/erniebot_agent/memory/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .base import Memory
16
+ from .limit_token_memory import LimitTokensMemory
17
+ from .sliding_window_memory import SlidingWindowMemory
18
+ from .whole_memory import WholeMemory
erniebot-agent/erniebot_agent/memory/base.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Optional, Union
16
+
17
+ from erniebot_agent.messages import AIMessage, Message, SystemMessage
18
+
19
+
20
class MessageManager:
    """Maintains the message history for a memory object.

    Holds at most one system message; all other messages are kept in
    chronological order in ``messages``.
    """

    def __init__(self) -> None:
        self.messages: List[Message] = []
        self._system_message: Union[SystemMessage, None] = None

    @property
    def system_message(self) -> Optional[Message]:
        """
        The message manager have only one system message.

        return: Message or None
        """
        return self._system_message

    @system_message.setter
    def system_message(self, message: SystemMessage) -> None:
        if self._system_message is not None:
            # Fix: the original built a `Warning(...)` instance without
            # raising or emitting it, so the notice was silently discarded.
            import warnings

            warnings.warn("system message has been set, the previous one will be replaced")

        self._system_message = message

    def add_messages(self, messages: List[Message]) -> None:
        """Append multiple messages to the history."""
        self.messages.extend(messages)

    def add_message(self, message: Message) -> None:
        """Append a message; a ``SystemMessage`` replaces the system message instead."""
        if isinstance(message, SystemMessage):
            self.system_message = message
        else:
            self.messages.append(message)

    def pop_message(self) -> Message:
        """Remove and return the oldest (first) message."""
        return self.messages.pop(0)

    def clear_messages(self) -> None:
        """Drop all non-system messages."""
        self.messages = []

    def update_last_message_token_count(self, token_count: int):
        """Set the last message's token count.

        A count of 0 means "unknown": fall back to the content length as an
        estimate.
        """
        if token_count == 0:
            self.messages[-1].token_count = len(self.messages[-1].content)
        else:
            self.messages[-1].token_count = token_count

    def retrieve_messages(self) -> List[Message]:
        """Return the (mutable) list of non-system messages."""
        return self.messages
68
+
69
+
70
class Memory:
    """The base class of memory.

    Delegates storage to a ``MessageManager``; subclasses override
    ``add_message`` to implement truncation policies.
    """

    def __init__(self):
        self.msg_manager = MessageManager()

    def add_messages(self, messages: List[Message]):
        """Add messages one at a time so per-message hooks apply."""
        for message in messages:
            self.add_message(message)

    def add_message(self, message: Message):
        """Add a single message to the underlying message manager.

        For an ``AIMessage``, the preceding query's token count is first
        backfilled onto the last stored message.
        """
        if isinstance(message, AIMessage):
            self.msg_manager.update_last_message_token_count(message.query_tokens_count)
        self.msg_manager.add_message(message)

    def get_messages(self) -> List[Message]:
        """Return the stored (non-system) messages."""
        return self.msg_manager.retrieve_messages()

    def get_system_message(self) -> Optional[SystemMessage]:
        # Annotation fixed: this is None until a system message is set.
        return self.msg_manager.system_message

    def clear_chat_history(self):
        """Drop all non-system messages."""
        self.msg_manager.clear_messages()
93
+
94
+
95
class WholeMemory(Memory):
    """A memory that keeps every message with no truncation policy."""

    def __init__(self):
        super().__init__()
erniebot-agent/erniebot_agent/memory/limit_token_memory.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from erniebot_agent.memory import Memory
17
+ from erniebot_agent.messages import AIMessage, Message
18
+
19
+
20
class LimitTokensMemory(Memory):
    """Memory that keeps the total token count below ``max_token_limit``.

    When the running total exceeds the limit, the oldest messages are
    popped until the conversation fits again.

    Args:
        max_token_limit: Maximum number of tokens to retain, or ``None``
            for no limit. Must be a positive integer when given.

    Raises:
        ValueError: If ``max_token_limit`` is neither ``None`` nor positive.
            (Originally an ``assert``, which is stripped under ``-O``.)
    """

    def __init__(self, max_token_limit=6000):
        super().__init__()
        if max_token_limit is not None and max_token_limit <= 0:
            raise ValueError(
                "max_token_limit should be None or positive integer, "
                "but got {}".format(max_token_limit)
            )
        self.max_token_limit = max_token_limit
        self.mem_token_count = 0

    def add_message(self, message: Message):
        """Add a message; prune after each AI reply, when usage is known."""
        super().add_message(message)
        # TODO(shiyutang): pruning only on AIMessage means a long HumanMessage
        # may still exceed the limit when sent to the LLM; ideally every
        # message would carry a token_count so we could prune on every add.
        if isinstance(message, AIMessage):
            self.prune_message()

    def prune_message(self):
        """Update the running token count and evict oldest messages as needed.

        Raises:
            RuntimeError: If pruning empties the memory entirely, i.e. a
                single exchange is larger than ``max_token_limit``.
        """
        messages = self.msg_manager.messages
        self.mem_token_count += messages[-1].token_count
        if len(messages) >= 2:
            # Bug fix: guard the human-message lookup — with a single stored
            # message the original `messages[-2]` raised IndexError.
            self.mem_token_count += messages[-2].token_count
        if self.max_token_limit is None:
            return
        # Bug fix: restructured the confusing `while/else` so the deleted
        # message is only referenced when something was actually deleted.
        deleted_message = None
        while self.mem_token_count > self.max_token_limit:
            deleted_message = self.msg_manager.pop_message()
            self.mem_token_count -= deleted_message.token_count
        if deleted_message is not None and len(self.get_messages()) == 0:
            raise RuntimeError(
                "The message is now empty. "
                "It indicates {} which takes up {} tokens and exceeded {} tokens.".format(
                    deleted_message, len(deleted_message.content), self.max_token_limit
                )
            )
erniebot-agent/erniebot_agent/memory/sliding_window_memory.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from erniebot_agent.memory import Memory
16
+ from erniebot_agent.messages import Message
17
+
18
+
19
class SlidingWindowMemory(Memory):
    """Memory that keeps only the most recent ``max_num_message`` messages."""

    def __init__(self, max_num_message: int):
        super().__init__()
        self.max_num_message = max_num_message

        assert (isinstance(max_num_message, int)) and (
            max_num_message > 0
        ), "max_num_message should be positive integer, but got {max_token_limit}".format(
            max_token_limit=max_num_message
        )

    def add_message(self, message: Message):
        """Add a message, then trim the history to the window size."""
        super().add_message(message=message)
        self.prune_message()

    def prune_message(self):
        """Evict oldest messages until the window size is respected."""
        while len(self.get_messages()) > self.max_num_message:
            self.msg_manager.pop_message()
            # `messages` must have an odd number of elements, so after each
            # eviction an extra pop keeps human/AI pairs aligned.
            if len(self.get_messages()) % 2 == 0:
                self.msg_manager.pop_message()
erniebot-agent/erniebot_agent/memory/whole_memory.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from erniebot_agent.memory.base import Memory
16
+
17
+
18
class WholeMemory(Memory):
    """Memory that retains the entire message history; nothing is pruned."""
erniebot-agent/erniebot_agent/messages.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Licensed under the Apache License, Version 2.0 (the "License"
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License
13
+
14
+ from typing import Dict, List, Optional, TypedDict
15
+
16
+ import erniebot.utils.token_helper as token_helper
17
+
18
+
19
class Message:
    """Base class for all chat messages.

    Args:
        role: The speaker role (e.g. "user", "assistant").
        content: The text content of the message.
        token_count: Number of tokens of the content, if already known.
    """

    def __init__(self, role: str, content: str, token_count: Optional[int] = None):
        self.role = role
        self.content = content
        self._token_count = token_count
        # Attribute names serialized by to_dict() and shown by __repr__;
        # subclasses extend this list.
        self._param_names = ["role", "content"]

    @property
    def token_count(self):
        """Number of tokens of the message; raises until it has been set."""
        if self._token_count is None:
            raise AttributeError("The token count of the message has not been set.")
        return self._token_count

    @token_count.setter
    def token_count(self, token_count: int):
        """Set the number of tokens; the value is write-once."""
        if self._token_count is not None:
            raise AttributeError("The token count of the message can only be set once.")
        self._token_count = token_count

    def to_dict(self) -> Dict[str, str]:
        """Serialize the message to a plain dict of its parameters."""
        return {name: getattr(self, name) for name in self._param_names}

    def __str__(self) -> str:
        return f"<{self._get_attrs_str()}>"

    def __repr__(self):
        return f"<{self.__class__.__name__} {self._get_attrs_str()}>"

    def _get_attrs_str(self) -> str:
        """Build the "name: value" listing used by __str__ and __repr__."""
        shown = [
            f"{name}: {repr(value)}"
            for name in self._param_names
            for value in [getattr(self, name)]
            if value is not None and value != ""
        ]
        if self._token_count is not None:
            shown.append(f"token_count: {self._token_count}")
        return ", ".join(shown)
63
+
64
+
65
class SystemMessage(Message):
    """A message from a human to set system information.

    The token count is approximated up front as the content length
    (one token per character) — TODO confirm this approximation is intended.
    """

    def __init__(self, content: str):
        super().__init__(role="system", content=content, token_count=len(content))
70
+
71
+
72
class HumanMessage(Message):
    """A message authored by the human user (role "user")."""

    def __init__(self, content: str):
        super().__init__(role="user", content=content)
77
+
78
+
79
class FunctionCall(TypedDict):
    # Function-call payload carried by an assistant message.
    name: str  # name of the function to invoke
    thoughts: str  # model's reasoning behind the call
    arguments: str  # arguments for the call, as a string (presumably JSON — verify against caller)
83
+
84
+
85
class TokenUsage(TypedDict):
    # Token accounting reported by the LLM service for one exchange.
    prompt_tokens: int  # tokens consumed by the prompt (query) side
    completion_tokens: int  # tokens produced in the completion (reply)
89
+
90
class AIMessage(Message):
    """A message from an assistant (an LLM reply)."""

    def __init__(
        self,
        content: str,
        function_call: Optional[FunctionCall],
        token_usage: Optional[TokenUsage] = None,
    ):
        if token_usage is None:
            # Without usage info from the service, approximate the completion
            # size from the content; the prompt side is unknown (0).
            prompt_tokens = 0
            completion_tokens = token_helper.approx_num_tokens(content)
        else:
            prompt_tokens, completion_tokens = self._parse_token_count(token_usage)
        super().__init__(role="assistant", content=content, token_count=completion_tokens)
        self.function_call = function_call
        # Prompt-side token count; Memory.add_message uses this to backfill
        # the token count of the preceding message.
        self.query_tokens_count = prompt_tokens
        self._param_names = ["role", "content", "function_call"]

    def _parse_token_count(self, token_usage: TokenUsage):
        """Split the usage dict into (prompt_tokens, completion_tokens)."""
        return token_usage["prompt_tokens"], token_usage["completion_tokens"]
112
+
113
+
114
class FunctionMessage(Message):
    """A message from a human, containing the result of a function call."""

    def __init__(self, name: str, content: str):
        super().__init__(role="function", content=content)
        # Name of the function whose result this message carries; included
        # in serialization via _param_names.
        self.name = name
        self._param_names = ["role", "name", "content"]
121
+
122
+
123
class AIMessageChunk(AIMessage):
    """A streamed chunk of an assistant message; behaves like AIMessage."""
erniebot-agent/erniebot_agent/prompt/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .base import BasePromptTemplate
16
+ from .prompt_template import PromptTemplate
erniebot-agent/erniebot_agent/prompt/base.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from abc import ABC, abstractmethod
15
+ from typing import List, Optional
16
+
17
+
18
class BasePromptTemplate(ABC):
    """Abstract interface for prompt templates.

    Args:
        input_variables: Names of the variables the template expects,
            or None to skip validation.
    """

    def __init__(self, input_variables: Optional[List[str]]):
        self.input_variables: Optional[List[str]] = input_variables

    @abstractmethod
    def format(self, **kwargs):
        """Render the template to a string using the given variables."""
        raise NotImplementedError

    @abstractmethod
    def format_as_message(self, message_class, **kwargs):
        """Render the template and wrap the result in a message object."""
        raise NotImplementedError
erniebot-agent/erniebot_agent/prompt/prompt_template.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, List, Optional
16
+
17
+ from erniebot_agent.messages import HumanMessage
18
+ from erniebot_agent.prompt import BasePromptTemplate
19
+ from jinja2 import Environment, meta
20
+
21
+
22
def jinja2_formatter(template: str, **kwargs: Any) -> str:
    """Format a template using jinja2.

    Args:
        template: A jinja2 template string.
        **kwargs: Variables substituted into the template.

    Raises:
        ImportError: If jinja2 is not installed.
    """
    try:
        # Local import keeps the defensive ImportError message close to the
        # only place the dependency is actually exercised.
        from jinja2 import Template
    except ImportError:
        raise ImportError(
            "jinja2 not installed, which is needed to use the jinja2_formatter. "
            "Please install it with `pip install jinja2`."
        )

    return Template(template).render(**kwargs)
33
+
34
+
35
class PromptTemplate(BasePromptTemplate):
    """Jinja2-based template that renders prompt strings for LLM input.

    Args:
        template: The jinja2 template string.
        name: Optional identifier for this template.
        input_variables: Variable names the template expects; when given,
            they are validated against the template at ``format()`` time.
    """

    def __init__(
        self, template: str, name: Optional[str] = None, input_variables: Optional[List[str]] = None
    ):
        super().__init__(input_variables)
        self.name = name
        self.template = template
        # Idiom fix: direct bool instead of `True if ... else False`.
        # Validation is only possible when the expected variables are declared.
        self.validate_template = input_variables is not None  # TODO: verify the template itself is well formed

    def format(self, **kwargs) -> str:
        """Render the template with the given variables.

        Raises:
            KeyError: If declared and actually-used template variables disagree.
        """
        if self.validate_template:
            error = self._validate_template()
            if error:
                raise KeyError("The input_variables of PromptTemplate and template do not match! " + error)
        return jinja2_formatter(self.template, **kwargs)

    def _validate_template(self) -> str:
        """Compare declared input_variables with the variables the template uses.

        Returns:
            An error description; empty string when they match.
        """
        input_variables_set = set(self.input_variables)
        env = Environment()
        ast = env.parse(self.template)
        valid_variables = meta.find_undeclared_variables(ast)

        missing_variables = valid_variables - input_variables_set
        extra_variables = input_variables_set - valid_variables

        # Naming fix: local variables are snake_case, not capitalized.
        error_message = ""
        if missing_variables:
            error_message += f"The missing input variables: {missing_variables} "

        if extra_variables:
            error_message += f"The extra input variables: {extra_variables}"

        return error_message

    def format_as_message(self, **kwargs):
        """Render the template and wrap it as a HumanMessage.

        NOTE(review): the base class declares ``format_as_message(message_class,
        **kwargs)``; this override drops ``message_class`` — confirm intended.
        """
        prompt = self.format(**kwargs)
        return HumanMessage(content=prompt)
erniebot-agent/erniebot_agent/retrieval/__init__.py ADDED
File without changes
erniebot-agent/erniebot_agent/retrieval/baizhong_search.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import requests
7
+ from erniebot_agent.utils.exception import BaizhongError
8
+ from erniebot_agent.utils.logging import logger
9
+ from tqdm import tqdm
10
+
11
+ from .document import Document
12
+
13
+
14
class BaizhongSearch:
    """Thin client for the Baizhong search REST API.

    Manages a project (create/load/delete), its schema, and the documents
    inside it, and runs searches against it.
    """

    def __init__(
        self,
        base_url: str,
        project_name: Optional[str] = None,
        remark: Optional[str] = None,
        project_id: Optional[int] = None,
        max_seq_length: int = 512,
    ) -> None:
        """Load an existing project (by ``project_id``) or create a new one
        (by ``project_name``), creating its schema as well.

        Raises:
            BaizhongError: If neither ``project_name`` nor ``project_id`` is
                given, or if any API call fails.
        """
        self.base_url = base_url
        self.max_seq_length = max_seq_length
        if project_id is not None:
            logger.info(f"Loading existing project with `project_id={project_id}`")
            self.project_id = project_id
        elif project_name is not None:
            logger.info("Creating new project and schema")
            self.index = self.create_project(project_name, remark)
            logger.info("Project creation succeeded")
            self.project_id = self.index["result"]["projectId"]
            self.create_schema()
            logger.info("Schema creation succeeded")
        else:
            raise BaizhongError("You must provide either a `project_name` or a `project_id`.")

    @staticmethod
    def _post(url: str, json_data: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``json_data`` to ``url`` and return the parsed response body.

        Centralizes the status-code / errCode handling previously duplicated
        in every endpoint method.

        Raises:
            BaizhongError: On non-200 HTTP status or non-zero ``errCode``.
        """
        res = requests.post(url, json=json_data)
        if res.status_code != 200:
            raise BaizhongError(message=f"request error: {res.text}", error_code=res.status_code)
        result = res.json()
        if result["errCode"] != 0:
            raise BaizhongError(message=result["errMsg"], error_code=result["errCode"])
        return result

    def _schema_payload(self) -> Dict[str, Any]:
        """Schema JSON shared by create_schema() and update_schema()."""
        return {
            "projectId": self.project_id,
            "schemaJson": {
                "paraSize": self.max_seq_length,
                "dataSegmentationMod": "neisou",
                "storeType": "ElasticSearch",
                "properties": {
                    "title": {"type": "text", "shortindex": True},
                    "content_se": {"type": "text", "longindex": True},
                },
            },
        }

    def create_project(self, project_name: str, remark: Optional[str] = None):
        """Create a project using the Baizhong API.

        Returns:
            dict: Information about the created project.

        Raises:
            BaizhongError: If the API request fails.
        """
        json_data = {
            "projectName": project_name,
            "remark": remark,
        }
        return self._post(f"{self.base_url}/baizhong/web-api/v2/project/add", json_data)

    def create_schema(self):
        """Create the schema for this project.

        Returns:
            dict: Information about the created schema.

        Raises:
            BaizhongError: If the API request fails.
        """
        return self._post(
            f"{self.base_url}/baizhong/web-api/v2/project-schema/create", self._schema_payload()
        )

    def update_schema(self):
        """Update the schema for this project.

        Returns:
            dict: Information about the updated schema.

        Raises:
            BaizhongError: If the API request fails.
        """
        return self._post(
            f"{self.base_url}/baizhong/web-api/v2/project-schema/update", self._schema_payload()
        )

    def search(self, query: str, top_k: int = 10, filters: Optional[Dict[str, Any]] = None):
        """Perform a search using the Baizhong common search API.

        Args:
            query: The search query.
            top_k: Number of top results to retrieve (default 10).
            filters: Additional match filters to apply (default None).

        Returns:
            List[Dict[str, Any]]: Decoded documents from the search hits.

        Raises:
            BaizhongError: If the API request fails.
        """
        json_data: Dict[str, Any] = {
            "query": query,
            "projectId": self.project_id,
            "size": top_k,
        }
        if filters is not None:
            json_data["filterConditions"] = {"bool": {"filter": {"match": filters}}}
        result = self._post(f"{self.base_url}/baizhong/common-search/v2/search", json_data)
        # Each hit carries a base64-encoded JSON document in `_source.doc`.
        return [
            json.loads(base64.b64decode(item["_source"]["doc"]).decode("utf-8"))
            for item in result["hits"]
        ]

    def add_documents(self, documents: List[Document], batch_size: int = 1, thread_count: int = 1):
        """Add a batch of documents using multi-threading.

        Args:
            documents: Document objects (or pre-serialized dicts) to add.
            batch_size: Size of each batch (default 1).
            thread_count: Number of worker threads (default 1).

        Returns:
            List[Union[None, Exception]]: Results of the additions.
        """
        # Bug fix: the original only built `list_dicts` when the first item
        # was a Document, leaving it unbound (NameError) for dict input.
        list_dicts = [item.to_dict() if isinstance(item, Document) else item for item in documents]
        batches = [
            list_dicts[i : i + batch_size] for i in tqdm(range(0, len(list_dicts), batch_size))
        ]
        with ThreadPoolExecutor(max_workers=thread_count) as executor:
            results = executor.map(self._add_documents, batches)
        return list(results)

    def get_document_by_id(self, doc_id):
        """Retrieve a document by its ID.

        Returns:
            dict: Information about the retrieved document.

        Raises:
            BaizhongError: If the API request fails.
        """
        json_data = {"projectId": self.project_id, "followIndexFlag": True, "dataBody": [doc_id]}
        return self._post(f"{self.base_url}/baizhong/data-api/v2/flush/get", json_data)

    def delete_documents(
        self,
        ids: Optional[List[str]] = None,
    ):
        """Delete documents from the Baizhong system.

        Args:
            ids: Document IDs to delete. Deleting all documents (ids=None)
                is not implemented yet.

        Returns:
            dict: Information about the deletion.

        Raises:
            NotImplementedError: If ``ids`` is None.
            BaizhongError: If the API request fails.
        """
        json_data: Dict[str, Any] = {"projectId": self.project_id, "followIndexFlag": True}
        if ids is None:
            # TODO: delete all documents
            raise NotImplementedError
        json_data["dataBody"] = ids
        return self._post(f"{self.base_url}/baizhong/data-api/v2/flush/delete", json_data)

    def _add_documents(self, documents: List[Dict[str, Any]]):
        """Add one pre-serialized batch of documents.

        Returns:
            dict: Information about the addition.

        Raises:
            BaizhongError: If the API request fails.
        """
        # TODO(wugaosheng): retry 3 times on failure
        json_data = {"projectId": self.project_id, "followIndexFlag": True, "dataBody": documents}
        return self._post(f"{self.base_url}/baizhong/data-api/v2/flush/add", json_data)

    @classmethod
    def delete_project(cls, project_id: int, base_url: Optional[str] = None):
        """Delete a project.

        Bug fix: the original read ``cls.base_url``, but ``base_url`` is only
        ever set on instances, so every call raised AttributeError. The URL
        must now be passed explicitly (kept optional for signature
        compatibility).

        Args:
            project_id: ID of the project to delete.
            base_url: Base URL of the Baizhong service.

        Returns:
            dict: Information about the deletion.

        Raises:
            BaizhongError: If ``base_url`` is missing or the request fails.
        """
        if base_url is None:
            raise BaizhongError("`base_url` is required to delete a project.")
        json_data = {"projectId": project_id}
        return cls._post(f"{base_url}/baizhong/web-api/v2/project/delete", json_data)
erniebot-agent/erniebot_agent/retrieval/document.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import json
3
+ from typing import Any, Dict, Optional, Union
4
+
5
+ from pydantic import BaseConfig
6
+ from pydantic.dataclasses import dataclass
7
+
8
+ BaseConfig.arbitrary_types_allowed = True
9
+
10
+
11
+ @dataclass
12
+ class Document:
13
+ id: str
14
+ title: str
15
+ content_se: str
16
+ meta: Dict[str, Any]
17
+
18
+ def __init__(
19
+ self,
20
+ content_se: str,
21
+ title: str,
22
+ id: Optional[str] = None,
23
+ meta: Optional[Dict[str, Any]] = None,
24
+ ):
25
+ self.content_se = content_se
26
+ self.title = title
27
+ self.id = id or self._get_id()
28
+ self.meta = meta or {}
29
+
30
+ @classmethod
31
+ def _get_id(cls, content_se=None) -> str:
32
+ md5_bytes = content_se.encode(encoding="UTF-8")
33
+ md5_string = hashlib.md5(md5_bytes).hexdigest()
34
+ return md5_string
35
+
36
+ def to_dict(self, field_map: Optional[Dict[str, Any]] = None) -> Dict:
37
+ """
38
+ Convert Document to dict. An optional field_map can be supplied to
39
+ change the names of the keys in the resulting dict.
40
+ This way you can work with standardized Document objects in erniebot-agent,
41
+ but adjust the format that they are serialized / stored in other places
42
+ (e.g. elasticsearch)
43
+ Example:
44
+
45
+ ```python
46
+ doc = Document(content="some text", content_type="text")
47
+ doc.to_dict(field_map={"custom_content_field": "content"})
48
+
49
+ # Returns {"custom_content_field": "some text"}
50
+ ```
51
+
52
+ :param field_map: Dict with keys being the custom target keys and values
53
+ being the standard Document attributes
54
+ :return: dict with content of the Document
55
+ """
56
+ if not field_map:
57
+ field_map = {}
58
+
59
+ inv_field_map = {v: k for k, v in field_map.items()}
60
+ _doc: Dict[str, str] = {}
61
+ for k, v in self.__dict__.items():
62
+ # Exclude internal fields (Pydantic, ...) fields from the conversion process
63
+ if k.startswith("__"):
64
+ continue
65
+ k = k if k not in inv_field_map else inv_field_map[k]
66
+ _doc[k] = v
67
+ return _doc
68
+
69
+ @classmethod
70
+ def from_dict(cls, dict: Dict[str, Any], field_map: Optional[Dict[str, Any]] = None):
71
+ """
72
+ Create Document from dict. An optional `field_map` parameter
73
+ can be supplied to adjust for custom names of the keys in the
74
+ input dict. This way you can work with standardized Document
75
+ objects in erniebot-agent, but adjust the format that
76
+ they are serialized / stored in other places (e.g. elasticsearch).
77
+
78
+ Example:
79
+
80
+ ```python
81
+ my_dict = {"custom_content_field": "some text", "content_type": "text"}
82
+ Document.from_dict(my_dict, field_map={"custom_content_field": "content"})
83
+ ```
84
+
85
+ :param field_map: Dict with keys being the custom target keys and values
86
+ being the standard Document attributes
87
+ :return: A Document object
88
+ """
89
+ if not field_map:
90
+ field_map = {}
91
+
92
+ _doc = dict.copy()
93
+ init_args = ["content_se", "meta", "id", "title"]
94
+ if "meta" not in _doc.keys():
95
+ _doc["meta"] = {}
96
+ if "id" not in _doc.keys():
97
+ _doc["id"] = cls._get_id(_doc["content_se"])
98
+ # copy additional fields into "meta"
99
+ for k, v in _doc.items():
100
+ # Exclude internal fields (Pydantic, ...) fields from the conversion process
101
+ if k.startswith("__"):
102
+ continue
103
+ if k not in init_args and k not in field_map:
104
+ _doc["meta"][k] = v
105
+ # remove additional fields from top level
106
+ _new_doc = {}
107
+ for k, v in _doc.items():
108
+ if k in init_args:
109
+ _new_doc[k] = v
110
+ elif k in field_map:
111
+ k = field_map[k]
112
+ _new_doc[k] = v
113
+ return cls(**_new_doc)
114
+
115
+ @classmethod
116
+ def from_json(cls, data: Union[str, Dict[str, Any]], field_map: Optional[Dict[str, Any]] = None):
117
+ if not field_map:
118
+ field_map = {}
119
+ if isinstance(data, str):
120
+ dict_data = json.loads(data)
121
+ else:
122
+ dict_data = data
123
+ return cls.from_dict(dict_data, field_map=field_map)
erniebot-agent/erniebot_agent/tools/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .image_generation_tool import ImageGenerationTool
erniebot-agent/erniebot_agent/tools/baizhong_tool.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, List, Optional, Type
4
+
5
+ from erniebot_agent.messages import AIMessage, HumanMessage
6
+ from erniebot_agent.tools.schema import ToolParameterView
7
+ from pydantic import Field
8
+
9
+ from .base import Tool
10
+
11
+
12
class BaizhongSearchToolInputView(ToolParameterView):
    """Input schema for BaizhongSearchTool: the query text and result count."""

    query: str = Field(description="Query")
    top_k: int = Field(description="Number of results to return")
15
+
16
+
17
class SearchResponseDocument(ToolParameterView):
    """One retrieved document: its id, title, and body text."""

    id: str = Field(description="text id")
    title: str = Field(description="title")
    document: str = Field(description="content")
21
+
22
+
23
class BaizhongSearchToolOutputView(ToolParameterView):
    """Output schema for BaizhongSearchTool: the list of retrieved documents."""

    documents: List[SearchResponseDocument] = Field(description="research results")
25
+
26
+
27
class BaizhongSearchTool(Tool):
    """Tool that queries a Baizhong search database and returns matching documents."""

    description: str = "aurora search tool"
    input_type: Type[ToolParameterView] = BaizhongSearchToolInputView
    ouptut_type: Type[ToolParameterView] = BaizhongSearchToolOutputView

    def __init__(self, description, db, input_type=None, output_type=None, examples=None) -> None:
        """Bind the tool to a search database.

        Args:
            description: human/LLM-readable description of this tool.
            db: the search backend; must expose ``search(query, top_k, filters)``.
            input_type: optional override of the input schema view.
            output_type: optional override of the output schema view.
            examples: optional few-shot examples, each a dict with
                ``user``, ``thoughts`` and ``arguments`` keys.
        """
        super().__init__()
        self.db = db
        self.description = description
        if input_type is not None:
            self.input_type = input_type
        if output_type is not None:
            self.ouptut_type = output_type
        # Always initialize the attribute: the original only assigned it when
        # `examples is not None`, so the `examples` property (used by
        # function_call_schema) raised AttributeError for tools created
        # without few-shot examples.
        self.few_shot_examples = examples if examples is not None else []

    async def __call__(self, query: str, top_k: int = 10, filters: Optional[dict[str, Any]] = None):
        """Run a search against the underlying database and return its raw result."""
        res = self.db.search(query, top_k, filters)
        return res

    @property
    def examples(
        self,
    ) -> List[Any]:
        """Render the stored few-shot examples as alternating Human/AI messages."""
        few_shot_objects: List[Any] = []
        for item in self.few_shot_examples:
            few_shot_objects.append(HumanMessage(item["user"]))
            few_shot_objects.append(
                AIMessage(
                    "",
                    function_call={
                        "name": self.tool_name,
                        "thoughts": item["thoughts"],
                        "arguments": item["arguments"],
                    },
                )
            )

        return few_shot_objects
erniebot-agent/erniebot_agent/tools/base.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ import os
19
+ import tempfile
20
+ from abc import ABC, abstractmethod
21
+ from dataclasses import asdict, dataclass, field
22
+ from typing import Any, Dict, List, Optional, Type
23
+
24
+ import requests
25
+ from erniebot_agent.messages import AIMessage, FunctionCall, HumanMessage, Message
26
+ from erniebot_agent.tools.schema import (
27
+ Endpoint,
28
+ EndpointInfo,
29
+ RemoteToolView,
30
+ ToolParameterView,
31
+ scrub_dict,
32
+ )
33
+ from erniebot_agent.utils.http import url_file_exists
34
+ from erniebot_agent.utils.logging import logger
35
+ from openapi_spec_validator import validate
36
+ from openapi_spec_validator.readers import read_from_filename
37
+ from yaml import safe_dump
38
+
39
+ import erniebot
40
+
41
+
42
def validate_openapi_yaml(yaml_file: str) -> bool:
    """Validate an OpenAPI YAML document.

    Args:
        yaml_file (str): path to the YAML file.

    Returns:
        bool: True when the document passes OpenAPI spec validation,
            False otherwise (the validation error is logged).
    """
    spec, _ = read_from_filename(yaml_file)
    try:
        validate(spec)
    except Exception as e:  # type: ignore
        logger.error(e)
        return False
    return True
+ return False
58
+
59
+
60
class BaseTool(ABC):
    """Abstract interface shared by local tools and remote tools.

    A concrete tool must be awaitable (``__call__``) and must describe
    itself via ``function_call_schema`` for the LLM function-calling API.
    """

    @abstractmethod
    async def __call__(self, *args: Any, **kwds: Any) -> Any:
        # Execute the tool with caller-provided arguments.
        raise NotImplementedError

    @abstractmethod
    def function_call_schema(self) -> dict:
        # Return the schema dict describing this tool for function calling.
        raise NotImplementedError
+ raise NotImplementedError
68
+
69
+
70
class Tool(BaseTool, ABC):
    """Base class for locally-implemented tools.

    Subclasses provide a description, optional input/output schema views and
    an async ``__call__`` body; this class assembles the function-call schema.
    """

    description: str
    name: Optional[str] = None
    input_type: Optional[Type[ToolParameterView]] = None
    # NOTE(review): "ouptut_type" is a misspelling of "output_type", kept
    # because subclasses across this file assign it under this exact name.
    ouptut_type: Optional[Type[ToolParameterView]] = None

    def __str__(self) -> str:
        return f"<name: {self.name}, description: {self.description}>"

    def __repr__(self):
        return self.__str__()

    @property
    def tool_name(self):
        # Fall back to the class name when no explicit name was given.
        return self.name or self.__class__.__name__

    @abstractmethod
    async def __call__(self, *args: Any, **kwds: Any) -> Dict[str, Any]:
        """Execute the tool; subclasses implement the actual behavior.

        Returns:
            Dict[str, Any]: the tool's result payload.
        """
        raise NotImplementedError

    def function_call_schema(self) -> dict:
        """Build the function-calling schema from the tool's metadata.

        Includes name, description and serialized few-shot examples, plus
        the declared input/output parameter views when present.
        """
        schema: Dict[str, Any] = {
            "name": self.tool_name,
            "description": self.description,
            "examples": [example.to_dict() for example in self.examples],
        }
        if self.input_type is not None:
            schema["parameters"] = self.input_type.function_call_schema()
        if self.ouptut_type is not None:
            schema["responses"] = self.ouptut_type.function_call_schema()

        return scrub_dict(schema) or {}

    @property
    def examples(self) -> List[Message]:
        """Few-shot example messages; subclasses may override."""
        return []
+ return []
111
+
112
+
113
class RemoteTool(BaseTool):
    """A tool backed by a remote HTTP endpoint described by a RemoteToolView."""

    def __init__(
        self,
        tool_view: RemoteToolView,
        server_url: str,
        headers: dict,
        examples: Optional[List[Message]] = None,
    ) -> None:
        self.tool_view = tool_view
        self.server_url = server_url
        self.headers = headers
        self.examples = examples

    def __str__(self) -> str:
        return (
            f"<name: {self.tool_name}, server_url: {self.server_url}, "
            f"description: {self.tool_view.description}>"
        )

    def __repr__(self):
        return self.__str__()

    @property
    def tool_name(self):
        return self.tool_view.name

    async def __call__(self, **tool_arguments: Dict[str, Any]) -> Any:
        """Invoke the remote endpoint with the given arguments.

        GET requests send the arguments as query parameters; other verbs
        send them as a JSON body. Raises ValueError on a non-200 response
        or an unknown HTTP method.
        """
        url = self.server_url + self.tool_view.uri
        method = self.tool_view.method

        if method == "get":
            response = requests.get(url, params=tool_arguments, headers=self.headers)
        elif method in ("post", "put", "delete"):
            response = getattr(requests, method)(url, json=tool_arguments, headers=self.headers)
        else:
            raise ValueError(f"method<{method}> is invalid")

        if response.status_code != 200:
            raise ValueError(f"the resource is invalid, the error message is: {response.text}")

        return response.json()

    def function_call_schema(self) -> dict:
        """Schema comes from the tool view, augmented with any examples."""
        schema = self.tool_view.function_call_schema()
        if self.examples is not None:
            schema["examples"] = [example.to_dict() for example in self.examples]

        return schema or {}
+ return schema or {}
163
+
164
+
165
@dataclass
class RemoteToolkit:
    """RemoteToolkit can be converted by openapi.yaml and endpoint.

    Holds the parsed OpenAPI description of a remote plugin: its info block,
    server endpoints, per-path tool views, component schemas, request headers
    and optional few-shot example messages.
    """

    # OpenAPI spec version string, e.g. "3.0.1".
    openapi: str
    info: EndpointInfo
    servers: List[Endpoint]
    paths: List[RemoteToolView]

    component_schemas: dict[str, Type[ToolParameterView]]
    headers: dict
    examples: List[Message] = field(default_factory=list)

    def __getitem__(self, tool_name: str):
        # Allow `toolkit["tool_name"]` as a shorthand for get_tool().
        return self.get_tool(tool_name)

    def get_tools(self) -> List[RemoteTool]:
        """Instantiate a RemoteTool for every path, bound to the first server's URL."""
        return [
            RemoteTool(
                path, self.servers[0].url, self.headers, examples=self.get_examples_by_name(path.name)
            )
            for path in self.paths
        ]

    def get_examples_by_name(self, tool_name: str) -> List[Message]:
        """get examples by tool-name

        Args:
            tool_name (str): the name of the tool

        Returns:
            List[Message]: the messages
        """
        # 1. split messages into conversations; each HumanMessage starts a new one
        tool_examples: List[List[Message]] = []
        examples: List[Message] = []
        for example in self.examples:
            if isinstance(example, HumanMessage):
                if len(examples) == 0:
                    examples.append(example)
                else:
                    tool_examples.append(examples)
                    examples = [example]
            else:
                examples.append(example)

        if len(examples) > 0:
            tool_examples.append(examples)

        # NOTE(review): "final_exampels" is a typo for "final_examples" (local only).
        final_exampels: List[Message] = []
        # 2. find the target tool examples or empty messages
        for examples in tool_examples:
            # names of every tool invoked by AI messages within this conversation
            tool_names = [
                example.function_call.get("name", None)
                for example in examples
                if isinstance(example, AIMessage) and example.function_call is not None
            ]
            tool_names = [name for name in tool_names if name]

            if tool_name in tool_names:
                final_exampels.extend(examples)

        return final_exampels

    def get_tool(self, tool_name: str) -> RemoteTool:
        """Return the single RemoteTool whose path name matches ``tool_name``."""
        paths = [path for path in self.paths if path.name == tool_name]
        # NOTE(review): assert is stripped under `python -O`; an explicit
        # exception would be safer here — confirm before relying on it.
        assert len(paths) == 1, f"tool<{tool_name}> not found in paths"
        return RemoteTool(
            paths[0], self.servers[0].url, self.headers, examples=self.get_examples_by_name(tool_name)
        )

    def to_openapi_dict(self) -> dict:
        """convert plugin schema to openapi spec dict"""
        spec_dict = {
            "openapi": self.openapi,
            "info": asdict(self.info),
            "servers": [asdict(server) for server in self.servers],
            "paths": {tool_view.uri: tool_view.to_openapi_dict() for tool_view in self.paths},
            "components": {
                "schemas": {
                    uri: parameters_view.to_openapi_dict()
                    for uri, parameters_view in self.component_schemas.items()
                }
            },
        }
        # Empty sub-dicts are removed so the emitted spec stays minimal.
        return scrub_dict(spec_dict, remove_empty_dict=True) or {}

    def to_openapi_file(self, file: str):
        """generate openapi configuration file

        Args:
            file (str): the path of the openapi yaml file
        """
        spec_dict = self.to_openapi_dict()
        with open(file, "w+", encoding="utf-8") as f:
            safe_dump(spec_dict, f, indent=4)

    @classmethod
    def from_openapi_dict(
        cls, openapi_dict: Dict[str, Any], access_token: Optional[str] = None
    ) -> RemoteToolkit:
        """Build a toolkit from an already-parsed OpenAPI spec dictionary."""
        info = EndpointInfo(**openapi_dict["info"])
        servers = [Endpoint(**server) for server in openapi_dict.get("servers", [])]

        # components: each schema becomes a dynamically-built ToolParameterView subclass
        component_schemas = openapi_dict["components"]["schemas"]
        fields = {}
        for schema_name, schema in component_schemas.items():
            parameter_view = ToolParameterView.from_openapi_dict(schema_name, schema)
            fields[schema_name] = parameter_view

        # paths: one RemoteToolView per (uri, http-method) pair
        paths = []
        for path, path_info in openapi_dict.get("paths", {}).items():
            for method, path_method_info in path_info.items():
                paths.append(
                    RemoteToolView.from_openapi_dict(
                        uri=path,
                        method=method,
                        path_info=path_method_info,
                        parameters_views=fields,
                    )
                )

        return RemoteToolkit(
            openapi=openapi_dict["openapi"],
            info=info,
            servers=servers,
            paths=paths,
            component_schemas=fields,
            headers=cls._get_authorization_headers(access_token),
        )  # type: ignore

    @classmethod
    def from_openapi_file(cls, file: str, access_token: Optional[str] = None) -> RemoteToolkit:
        """only support openapi v3.0.1

        Args:
            file (str): the path of openapi yaml file
            access_token (Optional[str]): token used to build Authorization headers
        """
        if not validate_openapi_yaml(file):
            raise ValueError(f"invalid openapi yaml file: {file}")

        spec_dict, _ = read_from_filename(file)
        return cls.from_openapi_dict(spec_dict, access_token=access_token)

    @classmethod
    def _get_authorization_headers(cls, access_token: Optional[str]) -> dict:
        """Build request headers, adding a token Authorization header when available."""
        if access_token is None:
            # Fall back to the globally-configured erniebot token.
            access_token = erniebot.access_token

        headers = {"Content-Type": "application/json"}
        if access_token is None:
            logger.warning("access_token is NOT provided, this may cause 403 HTTP error..")
        else:
            headers["Authorization"] = f"token {access_token}"
        return headers

    @classmethod
    def from_url(cls, url: str, access_token: Optional[str] = None) -> RemoteToolkit:
        """Fetch <url>/.well-known/openapi.yaml and build a toolkit from it."""
        # 1. download openapy.yaml file to temp directory
        if not url.endswith("/"):
            url += "/"
        openapi_yaml_url = url + ".well-known/openapi.yaml"

        with tempfile.TemporaryDirectory() as temp_dir:
            response = requests.get(openapi_yaml_url, headers=cls._get_authorization_headers(access_token))
            if response.status_code != 200:
                raise ValueError(f"the resource is invalid, the error message is: {response.text}")

            file_content = response.content.decode("utf-8")
            if not file_content.strip():
                raise ValueError(f"the content is empty from: {openapi_yaml_url}")

            file_path = os.path.join(temp_dir, "openapi.yaml")
            with open(file_path, "w+", encoding="utf-8") as f:
                f.write(file_content)

            toolkit = RemoteToolkit.from_openapi_file(file_path, access_token=access_token)
            # Point all servers at the URL we actually downloaded from.
            for server in toolkit.servers:
                server.url = url

            toolkit.examples = cls.load_remote_examples_yaml(url, access_token)

        return toolkit

    @classmethod
    def load_remote_examples_yaml(cls, url: str, access_token: Optional[str] = None) -> List[Message]:
        """load remote examples by url: url/.well-known/examples.yaml

        Args:
            url (str): the base url of the remote toolkit
        """
        if not url.endswith("/"):
            url += "/"
        examples_yaml_url = url + ".well-known/examples.yaml"
        # examples.yaml is optional; quietly return no examples when absent.
        if not url_file_exists(examples_yaml_url, cls._get_authorization_headers(access_token)):
            return []

        examples = []
        with tempfile.TemporaryDirectory() as temp_dir:
            response = requests.get(examples_yaml_url, headers=cls._get_authorization_headers(access_token))
            if response.status_code != 200:
                raise ValueError(
                    f"Invalid resource, status_code: {response.status_code}, error message: {response.text}"
                )

            file_content = response.content.decode("utf-8")
            if not file_content.strip():
                raise ValueError(f"the content is empty from: {examples_yaml_url}")

            file_path = os.path.join(temp_dir, "examples.yaml")
            with open(file_path, "w+", encoding="utf-8") as f:
                f.write(file_content)

            examples = cls.load_examples_yaml(file_path)

        return examples

    @classmethod
    def load_examples_dict(cls, examples_dict: Dict[str, Any]) -> List[Message]:
        """Convert the parsed examples.yaml structure into chat messages.

        Raises:
            ValueError: when an example entry has a role other than
                "user"/"bot".
        """
        messages: List[Message] = []
        for examples in examples_dict["examples"]:
            examples = examples["context"]
            for example in examples:
                if "user" == example["role"]:
                    messages.append(HumanMessage(example["content"]))
                elif "bot" in example["role"]:
                    plugin = example["plugin"]
                    if "operationId" in plugin:
                        function_call: FunctionCall = {
                            "name": plugin["operationId"],
                            "thoughts": plugin["thoughts"],
                            "arguments": json.dumps(plugin["requestArguments"], ensure_ascii=False),
                        }
                    else:
                        # No operationId: keep the thought, with an empty call.
                        function_call = {
                            "name": "",
                            "thoughts": plugin["thoughts"],
                            "arguments": "{}",
                        }  # type: ignore
                    messages.append(AIMessage("", function_call=function_call))
                else:
                    raise ValueError(f"invald role: <{example['role']}>")
        return messages

    @classmethod
    def load_examples_yaml(cls, file: str) -> List[Message]:
        """load examples from yaml file

        Args:
            file (str): the path of examples file

        Returns:
            List[Message]: the list of messages
        """
        content: dict = read_from_filename(file)[0]
        if len(content) == 0 or "examples" not in content:
            raise ValueError("invalid examples configuration file")
        return cls.load_examples_dict(content)

    def function_call_schemas(self) -> List[dict]:
        # One function-calling schema per remote tool in this toolkit.
        return [tool.function_call_schema() for tool in self.get_tools()]
+ return [tool.function_call_schema() for tool in self.get_tools()]
erniebot-agent/erniebot_agent/tools/calculator_tool.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from __future__ import annotations

import ast
import operator
from typing import Dict, List, Type

from erniebot_agent.messages import AIMessage, HumanMessage, Message
from erniebot_agent.tools.schema import ToolParameterView
from pydantic import Field

from .base import Tool
10
+
11
+
12
class CalculatorToolInputView(ToolParameterView):
    """Input schema for CalculatorTool: a plain arithmetic formula string."""

    math_formula: str = Field(description='标准的数学公式,例如:"2+3"、"3 - 4 * 6", "(3 + 4) * (6 + 4)" 等。 ')
14
+
15
+
16
class CalculatorToolOutputView(ToolParameterView):
    """Output schema for CalculatorTool: the numeric result of the formula."""

    formula_result: float = Field(description="数学公式计算的结果")
18
+
19
+
20
# Whitelisted arithmetic operators for the safe formula evaluator.
_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


def _eval_math_node(node):
    """Recursively evaluate a whitelisted arithmetic AST node.

    Only numeric literals, the binary operators in ``_BINARY_OPS`` and the
    unary sign operators are allowed; anything else raises ValueError.
    """
    if isinstance(node, ast.Expression):
        return _eval_math_node(node.body)
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in _BINARY_OPS:
        return _BINARY_OPS[type(node.op)](_eval_math_node(node.left), _eval_math_node(node.right))
    if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
        return _UNARY_OPS[type(node.op)](_eval_math_node(node.operand))
    raise ValueError("unsupported expression in math formula")


class CalculatorTool(Tool):
    """Tool that evaluates plain arithmetic formulas supplied by the model."""

    description: str = "CalculatorTool用于执行数学公式计算"
    input_type: Type[ToolParameterView] = CalculatorToolInputView
    ouptut_type: Type[ToolParameterView] = CalculatorToolOutputView

    async def __call__(self, math_formula: str) -> Dict[str, float]:
        """Evaluate ``math_formula`` and return {"formula_result": value}.

        SECURITY: the original implementation used ``eval()`` on
        model-provided input, which allows arbitrary code execution. The
        formula is now parsed with ``ast`` and only arithmetic operators on
        numeric literals are evaluated; valid formulas produce the same
        result as before, while anything else raises ValueError/SyntaxError.
        """
        tree = ast.parse(math_formula, mode="eval")
        return {"formula_result": _eval_math_node(tree)}

    @property
    def examples(self) -> List[Message]:
        """Few-shot examples showing how the model should call this tool."""
        return [
            HumanMessage("请告诉我三加六等于多少?"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": f"用户想知道3加6等于多少,我可以使用{self.tool_name}工具来计算公式,其中`math_formula`字段的内容为:'3+6'。",
                    "arguments": '{"math_formula": "3+6"}',
                },
                token_usage={
                    "prompt_tokens": 5,
                    "completion_tokens": 7,
                },  # TODO: Functional AIMessage will not add in the memory, will it add token_usage?
            ),
            HumanMessage("一加八再乘以5是多少?"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": f"用户想知道1加8再乘5等于多少,我可以使用{self.tool_name}工具来计算公式,"
                    "其中`math_formula`字段的内容为:'(1+8)*5'。",
                    "arguments": '{"math_formula": "(1+8)*5"}',
                },
                token_usage={"prompt_tokens": 5, "completion_tokens": 7},  # For test only
            ),
            HumanMessage("我想知道十二除以四再加五等于多少?"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": f"用户想知道12除以4再加5等于多少,我可以使用{self.tool_name}工具来计算公式,"
                    "其中`math_formula`字段的内容为:'12/4+5'。",
                    "arguments": '{"math_formula": "12/4+5"}',
                },
                token_usage={"prompt_tokens": 5, "completion_tokens": 7},  # For test only
            ),
        ]
erniebot-agent/erniebot_agent/tools/current_time_tool.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from datetime import datetime
4
+ from typing import Dict, List, Type
5
+
6
+ from erniebot_agent.messages import AIMessage, HumanMessage, Message
7
+ from erniebot_agent.tools.schema import ToolParameterView
8
+ from pydantic import Field
9
+
10
+ from .base import Tool
11
+
12
+
13
class CurrentTimeToolOutputView(ToolParameterView):
    """Output schema for CurrentTimeTool: the formatted current time string."""

    current_time: str = Field(description="当前时间")
15
+
16
+
17
class CurrentTimeTool(Tool):
    """Tool that reports the current local date and time."""

    description: str = "CurrentTimeTool 用于获取当前时间"
    ouptut_type: Type[ToolParameterView] = CurrentTimeToolOutputView

    async def __call__(self) -> Dict[str, str]:
        """Return {"current_time": formatted local time}.

        BUGFIX: the original format string "%Y年%m月%d号 %点:%分:%秒" omitted
        the strftime directive letters for hour/minute/second, so the time
        portion was never substituted; %H/%M/%S are now used explicitly.
        """
        return {"current_time": datetime.strftime(datetime.now(), "%Y年%m月%d号 %H点:%M分:%S秒")}

    @property
    def examples(self) -> List[Message]:
        """Few-shot examples showing how the model should call this tool."""
        return [
            HumanMessage("现在几点钟了"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": f"用户想知道现在几点了,我可以使用{self.tool_name}来获取当前时间,并从其中获得当前小时时间。",
                    "arguments": "{}",
                },
                token_usage={"prompt_tokens": 5, "completion_tokens": 7},  # For test only
            ),
            HumanMessage("现在是什么时候?"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": f"用户想知道现在几点了,我可以使用{self.tool_name}来获取当前时间",
                    "arguments": "{}",
                },
                token_usage={"prompt_tokens": 5, "completion_tokens": 7},  # For test only
            ),
        ]
erniebot-agent/erniebot_agent/tools/image_generation_tool.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import os
18
+ import uuid
19
+ from typing import Any, Dict, List, Optional, Type
20
+
21
+ from erniebot_agent.messages import AIMessage, HumanMessage, Message
22
+ from erniebot_agent.tools.base import Tool
23
+ from erniebot_agent.tools.schema import ToolParameterView
24
+ from erniebot_agent.utils.common import download_file, get_cache_dir
25
+ from pydantic import Field
26
+
27
+ import erniebot
28
+
29
+
30
class ImageGenerationInputView(ToolParameterView):
    """Input schema for ImageGenerationTool: prompt text plus image geometry/count."""

    prompt: str = Field(description="描述图像内容、风格的文本。例如:生成一张月亮的照片,月亮很圆。")
    width: int = Field(description="生成图片的宽度")
    height: int = Field(description="生成图片的高度")
    image_num: int = Field(description="生成图片的数量")
35
+
36
+
37
class ImageGenerationOutputView(ToolParameterView):
    """Output schema for ImageGenerationTool."""

    # NOTE(review): declared as a single str, but ImageGenerationTool.__call__
    # returns a list of paths under this key — confirm the intended shape.
    image_path: str = Field(description="图片在本地机器上的保存路径")
39
+
40
+
41
class ImageGenerationTool(Tool):
    """Tool that generates images from a text prompt via the ERNIE ViLG API."""

    description: str = "AI作图、生成图片、画图的工具"
    input_type: Type[ToolParameterView] = ImageGenerationInputView
    ouptut_type: Type[ToolParameterView] = ImageGenerationOutputView

    def __init__(
        self,
        yinian_access_token: Optional[str] = None,
        yinian_ak: Optional[str] = None,
        yinian_sk: Optional[str] = None,
    ) -> None:
        """Store the Yinian credentials: either an access token or an AK/SK pair.

        Raises:
            ValueError: when neither a token nor a complete AK/SK pair is given.
        """
        self.config: Dict[str, Optional[Any]]
        if yinian_access_token is not None:
            self.config = {"api_type": "yinian", "access_token": yinian_access_token}
        elif yinian_ak is not None and yinian_sk is not None:
            self.config = {"api_type": "yinian", "ak": yinian_ak, "sk": yinian_sk}
        else:
            raise ValueError("Please set the yinian_access_token, or set yinian_ak and yinian_sk")

    async def __call__(
        self,
        prompt: str,
        width: int = 512,
        height: int = 512,
        image_num: int = 1,
    ) -> Dict[str, List[str]]:
        """Generate image(s) and download each one to the local cache directory.

        Returns:
            Dict[str, List[str]]: {"image_path": [local path of each generated image]}
        """
        response = erniebot.Image.create(
            model="ernie-vilg-v2",
            prompt=prompt,
            width=width,
            height=height,
            image_num=image_num,
            _config_=self.config,
        )

        image_path = []
        cache_dir = get_cache_dir()
        for item in response["data"]["sub_task_result_list"]:
            image_url = item["final_image_list"][0]["img_url"]
            # Random file name avoids collisions between generated images.
            save_path = os.path.join(cache_dir, f"img_{uuid.uuid1()}.png")
            download_file(image_url, save_path)
            image_path.append(save_path)
        return {"image_path": image_path}

    @property
    def examples(self) -> List[Message]:
        """Few-shot examples showing how the model should call this tool."""
        return [
            HumanMessage("画一张小狗的图片,图像高度512,图像宽度512"),
            AIMessage(
                "",
                function_call={
                    # Consistency fix: use self.tool_name like the other
                    # examples instead of the hard-coded class name (same
                    # value by default, but respects a renamed tool).
                    "name": self.tool_name,
                    "thoughts": "用户需要我生成一张小狗的图片,图像高度为512,宽度为512。我可以使用ImageGenerationTool工具来满足用户的需求。",
                    "arguments": '{"prompt":"画一张小狗的图片,图像高度512,图像宽度512",'
                    '"width":512,"height":512,"image_num":1}',
                },
            ),
            HumanMessage("生成两张天空的图片"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": "用户想要生成两张天空的图片,我需要调用ImageGenerationTool工具的call接口,"
                    "并设置prompt为'生成两张天空的图片',width和height可以默认为512,image_num默认为2。",
                    "arguments": '{"prompt":"生成两张天空的图片","width":512,"height":512,"image_num":2}',
                },
            ),
            HumanMessage("使用AI作图工具,生成1张小猫的图片,高度和高度是1024"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": "用户需要生成一张小猫的图片,高度和宽度都是1024。我可以使用ImageGenerationTool工具来满足用户的需求。",
                    "arguments": '{"prompt":"生成一张小猫的照片。","width":1024,"height":1024,"image_num":1}',
                },
            ),
        ]
erniebot-agent/erniebot_agent/tools/schema.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import inspect
18
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, List, Optional, Type, get_args
20
+
21
+ from erniebot_agent.utils.logging import logger
22
+ from pydantic import BaseModel, Field, create_model
23
+ from pydantic.fields import FieldInfo
24
+
25
+ INVALID_FIELD_NAME = "__invalid_field_name__"
26
+
27
+
28
def is_optional_type(type: Type):
    """Return True when *type* is a Union that includes None (e.g. Optional[T])."""
    members = get_args(type)
    if not members:
        # Plain, non-generic annotations carry no args and cannot be Optional.
        return False
    return any(member is None.__class__ for member in members)
34
+
35
+
36
def get_typing_list_type(type):
    """Return the JSON type name of a typing.List[T]'s element type.

    Args:
        type: a typing annotation; anything other than typing.List[...] yields None.
    """
    # Only typing.List[...] aliases carry `_name == "List"`.
    if getattr(type, "_name", None) != "List":
        return None

    element_type = get_args(type)[0]
    return json_type(element_type)
48
+
49
+
50
def json_type(type: Optional[Type[object]] = None):
    """Map a Python/typing annotation to its JSON-schema type name.

    Handles the simple types (int/str/float/list), ToolParameterView
    subclasses, typing.List[...] aliases and Optional[T] wrappers; falls
    back to ``str(type)`` for anything it cannot classify.
    """
    if type is None:
        return "object"

    mapping = {
        int: "integer",
        str: "string",
        list: "array",
        List: "array",
        float: "number",
        ToolParameterView: "object",
    }

    # Subclasses of ToolParameterView serialize as JSON objects.
    if inspect.isclass(type) and issubclass(type, ToolParameterView):
        return "object"

    # typing.List[...] aliases expose `_name == "List"`.
    if getattr(type, "_name", None) == "List":
        return "array"

    if type not in mapping:
        # Unwrap Optional[T] / Union[T, None]: exactly one non-None member
        # is supported; anything else is rejected.
        args = [arg for arg in get_args(type) if arg is not None.__class__]
        if len(args) > 1 or len(args) == 0:
            raise ValueError(
                "only support simple type: FieldType=int/str/float/ToolParameterView, "
                "so the target type should be one of: FieldType, List[FieldType], "
                f"Optional[FieldType], but receive {type}"
            )
        type = args[0]

    if type in mapping:
        return mapping[type]

    # Second check covers a ToolParameterView subclass unwrapped from Optional.
    if inspect.isclass(type) and issubclass(type, ToolParameterView):
        return "object"

    return str(type)
86
+
87
+
88
def python_type_from_json_type(json_type_dict: dict) -> Type[object]:
    """Map an OpenAPI/JSON-schema type description back to a Python typing object.

    Args:
        json_type_dict: a schema fragment with a "type" key, and an "items"
            sub-dict when the type is "array".

    Returns:
        The corresponding Python type (int/str/float/ToolParameterView or a
        List[...] of one of those).

    Raises:
        ValueError: when the type is not a supported simple type or array,
            when an array lacks an item type, or when the item type is
            unsupported.
    """
    simple_types = {"integer": int, "string": str, "number": float, "object": ToolParameterView}
    if json_type_dict["type"] in simple_types:
        return simple_types[json_type_dict["type"]]

    # The original used `assert` for these checks; asserts are stripped under
    # `python -O`, so validate with explicit exceptions instead.
    if json_type_dict["type"] != "array":
        raise ValueError(f"only support simple_types<{','.join(simple_types)}> and array type")
    if "type" not in json_type_dict.get("items", {}):
        raise ValueError("<items> field must be defined when 'type'=array")

    json_type_value = json_type_dict["items"]["type"]
    array_types = {
        "string": List[str],
        "integer": List[int],
        "number": List[float],
        "object": List[ToolParameterView],
    }
    if json_type_value in array_types:
        return array_types[json_type_value]

    raise ValueError(f"unsupported data type: {json_type_value}")
109
+
110
+
111
def scrub_dict(d: dict, remove_empty_dict: bool = False) -> Optional[dict]:
    """Recursively drop None values from a nested dict structure.

    Keys whose (recursively scrubbed) value is None are removed. Lists are
    scrubbed element-wise but keep their length. A dict that ends up empty
    becomes {} by default, or None when ``remove_empty_dict`` is set (so the
    parent drops it too).

    Args:
        d: the value to scrub (dicts/lists recurse; other values pass through).
        remove_empty_dict: whether an empty dict collapses to None.

    Returns:
        The slimmed-down structure, or None for a removed empty dict.
    """
    if isinstance(d, list):
        return [scrub_dict(entry, remove_empty_dict) for entry in d]  # type: ignore
    if type(d) is not dict:
        return d

    cleaned = {}
    for key, value in d.items():
        scrubbed = scrub_dict(value, remove_empty_dict)
        if scrubbed is not None:
            cleaned[key] = scrubbed

    if cleaned:
        return cleaned
    return None if remove_empty_dict else {}
140
+
141
+
142
class OpenAPIProperty(BaseModel):
    """Pydantic model of a single OpenAPI schema property."""

    type: str
    description: Optional[str] = None
    required: Optional[List[str]] = None
    # Item schema, populated when type == "array".
    items: dict = Field(default_factory=dict)
    # Nested property schemas, populated when type == "object".
    properties: dict = Field(default_factory=dict)
148
+
149
+
150
def get_field_openapi_property(field_info: FieldInfo) -> OpenAPIProperty:
    """convert pydantic FieldInfo instance to OpenAPIProperty value

    Args:
        field_info (FieldInfo): the field instance

    Returns:
        OpenAPIProperty: the converted OpenAPI Property
    """
    typing_list_type = get_typing_list_type(field_info.annotation)
    if typing_list_type is not None:
        field_type = "array"
    elif is_optional_type(field_info.annotation):
        # Optional[T]: describe the inner T.
        field_type = json_type(get_args(field_info.annotation)[0])
    else:
        field_type = json_type(field_info.annotation)

    property = {
        "type": field_type,
        "description": field_info.description,
    }

    if property["type"] == "array":
        if typing_list_type == "object":
            # List[SubModel]: embed the sub-model's schema as the item type.
            list_type: Type[ToolParameterView] = get_args(field_info.annotation)[0]
            property["items"] = list_type.to_openapi_dict()
        else:
            property["items"] = {"type": typing_list_type}
    elif property["type"] == "object":
        if is_optional_type(field_info.annotation):
            # Optional[SubModel]: unwrap before reading the schema.
            field_type_class: Type[ToolParameterView] = get_args(field_info.annotation)[0]
        else:
            field_type_class = field_info.annotation

        openapi_dict = field_type_class.to_openapi_dict()
        property.update(openapi_dict)

    # NOTE(review): "description" is always present at this point, so this
    # get() never applies its default and a None description passes through
    # unchanged — confirm whether "" was intended for missing descriptions.
    property["description"] = property.get("description", "")
    return OpenAPIProperty(**property)
189
+
190
+
191
class ToolParameterView(BaseModel):
    """Schema view of a tool's input/output parameters.

    Subclasses (usually created dynamically via ``create_model``) describe the
    fields a tool accepts or returns, and can be converted to/from OpenAPI
    dicts as well as to an ERNIE function_call schema.
    """

    @classmethod
    def from_openapi_dict(cls, name, schema: dict) -> Type[ToolParameterView]:
        """parse openapi component schemas to ParameterView

        Args:
            name: name hint for the view (kept for interface compatibility)
            schema (dict): one openapi component schema (with `properties`)

        Returns:
            Type[ToolParameterView]: a dynamically created pydantic model
        """

        # TODO(wj-Mcat): to load Optional field
        fields = {}
        for field_name, field_dict in schema.get("properties", {}).items():
            field_type = python_type_from_json_type(field_dict)

            if field_type is List[ToolParameterView]:
                # Nested object arrays: build a sub-view from the `items` schema.
                SubParameterView: Type[ToolParameterView] = ToolParameterView.from_openapi_dict(
                    field_name, field_dict["items"]
                )
                field_type = List[SubParameterView]  # type: ignore

            # TODO(wj-Mcat): remove supporting for `summary` field
            # Prefer `description`; fall back to the deprecated `summary`.
            # Fix: the old code unconditionally re-read `description` in the
            # fallback branch, overwriting a present `summary` with None and
            # discarding it whenever `description` was absent.
            description = field_dict.get("description")
            if "summary" in field_dict:
                logger.info("`summary` field will be deprecated, please use `description`")
                if description is None:
                    description = field_dict["summary"]
                else:
                    logger.info("`description` field will be used instead of `summary`")

            description = description or ""

            field = FieldInfo(annotation=field_type, description=description)

            # TODO(wj-Mcat): to handle list field required & not-required
            # if get_typing_list_type(field_type) is not None:
            #     field.default_factory = list

            fields[field_name] = (field_type, field)

        return create_model("OpenAPIParameterView", __base__=ToolParameterView, **fields)

    @classmethod
    def to_openapi_dict(cls) -> dict:
        """convert ParametersView to openapi spec dict

        Returns:
            dict: schema of openapi
        """

        required_names, properties = [], {}
        for field_name, field_info in cls.model_fields.items():
            # Optional[...] fields are never marked required, even when they
            # have no default value.
            if field_info.is_required() and not is_optional_type(field_info.annotation):
                required_names.append(field_name)

            properties[field_name] = dict(get_field_openapi_property(field_info))

        result = {
            "type": "object",
            "properties": properties,
        }
        if len(required_names) > 0:
            result["required"] = required_names
        # Drop empty/None nodes so the emitted schema stays minimal.
        result = scrub_dict(result, remove_empty_dict=True)  # type: ignore
        return result or {}

    @classmethod
    def function_call_schema(cls) -> dict:
        """get function_call schema

        Returns:
            dict: the schema of function_call
        """
        return cls.to_openapi_dict()

    @classmethod
    def from_dict(cls, field_map: Dict[str, Any]):
        """
        Class method to create a Pydantic model dynamically based on a dictionary.

        Args:
            field_map (Dict[str, Any]): A dictionary mapping field names to their corresponding type
                and description.

        Returns:
            PydanticModel: A dynamically created Pydantic model with fields specified by the
                input dictionary.

        Note:
            This method is used to create a Pydantic model dynamically based on the provided dictionary,
            where each field's type and description are specified in the input.

        """
        fields = {}
        for field_name, field_dict in field_map.items():
            field_type = field_dict["type"]
            description = field_dict["description"]
            field = FieldInfo(annotation=field_type, description=description)
            fields[field_name] = (field_type, field)
        return create_model(cls.__name__, __base__=ToolParameterView, **fields)
294
+
295
+
296
@dataclass
class RemoteToolView:
    """View of one remote-tool operation (a single HTTP path + method).

    Holds the endpoint location, human-readable metadata, and the
    input/output parameter views parsed from an OpenAPI spec.
    """

    # URL path of the operation, e.g. "/search".
    uri: str
    # HTTP method: one of get/post/put/delete.
    method: str
    # operationId from the spec; used as the function_call name.
    name: str
    description: str
    # Input schema view (None when the operation takes no request body).
    parameters: Optional[Type[ToolParameterView]] = None
    parameters_description: Optional[str] = None
    # Output schema view (None when the operation declares no responses).
    returns: Optional[Type[ToolParameterView]] = None
    returns_description: Optional[str] = None

    # Component-schema ref names ("#/components/schemas/<ref>") kept so the
    # view can be serialized back to an OpenAPI dict.
    returns_ref_uri: Optional[str] = None
    parameters_ref_uri: Optional[str] = None

    def to_openapi_dict(self):
        """Serialize this view back into an OpenAPI path-item fragment.

        Returns:
            dict: ``{method: operation-object}`` with $ref-based request/response
                schemas pointing into ``#/components/schemas/``.
        """
        result = {
            "operationId": self.name,
            "description": self.description,
        }
        if self.returns is not None:
            # Only a single "200" response is modeled.
            response = {
                "200": {
                    "description": self.returns_description,
                    "content": {
                        "application/json": {
                            "schema": {"$ref": "#/components/schemas/" + (self.returns_ref_uri or "")}
                        }
                    },
                }
            }
            result["responses"] = response

        if self.parameters is not None:
            parameters = {
                "required": True,
                "content": {
                    "application/json": {
                        "schema": {"$ref": "#/components/schemas/" + (self.parameters_ref_uri or "")}
                    }
                },
            }
            result["requestBody"] = parameters
        return {self.method: result}

    @staticmethod
    def from_openapi_dict(
        uri: str, method: str, path_info: dict, parameters_views: dict[str, Type[ToolParameterView]]
    ) -> RemoteToolView:
        """construct RemoteToolView from openapi spec-dict info

        Args:
            uri (str): the url path of remote tool
            method (str): http method: one of [get, post, put, delete]
            path_info (dict): the spec info of remote tool
            parameters_views (dict[str, ParametersView]):
                the dict of parameters views which are the schema of input/output of tool

        Returns:
            RemoteToolView: the instance of remote tool view
        """
        parameters_ref_uri, returns_ref_uri = None, None
        parameters, parameters_description = None, None
        if "requestBody" in path_info:
            # Only "application/json" bodies referenced via $ref are supported.
            request_ref = path_info["requestBody"]["content"]["application/json"]["schema"]["$ref"]
            # "#/components/schemas/<Name>" -> "<Name>"
            parameters_ref_uri = request_ref.split("/")[-1]
            assert parameters_ref_uri in parameters_views
            parameters = parameters_views[parameters_ref_uri]
            parameters_description = path_info["requestBody"].get("description", None)

        returns, returns_description = None, None
        if "responses" in path_info:
            # Only the first declared response status is inspected.
            response_ref = list(path_info["responses"].values())[0]["content"]["application/json"]["schema"][
                "$ref"
            ]
            returns_ref_uri = response_ref.split("/")[-1]
            assert returns_ref_uri in parameters_views
            returns = parameters_views[returns_ref_uri]
            returns_description = list(path_info["responses"].values())[0].get("description", None)

        return RemoteToolView(
            name=path_info["operationId"],
            parameters=parameters,
            parameters_description=parameters_description,
            returns=returns,
            returns_description=returns_description,
            description=path_info.get("description", path_info.get("summary", None)),
            method=method,
            uri=uri,
            # save ref id info
            returns_ref_uri=returns_ref_uri,
            parameters_ref_uri=parameters_ref_uri,
        )

    def function_call_schema(self):
        """Build the ERNIE function_call schema dict for this operation.

        Returns:
            dict: name/description plus `parameters` (empty object schema when
                the tool takes no input) and optional `responses`.
        """
        inputs = {
            "name": self.name,
            "description": self.description,
            # TODO(wj-Mcat): read examples from openapi.yaml
            # "examples": [example.to_dict() for example in self.examples],
        }
        if self.parameters is not None:
            inputs["parameters"] = self.parameters.function_call_schema()  # type: ignore
        else:
            inputs["parameters"] = {"type": "object", "properties": {}}

        if self.returns is not None:
            inputs["responses"] = self.returns.function_call_schema()  # type: ignore
        return scrub_dict(inputs) or {}
404
+
405
+
406
@dataclass
class Endpoint:
    """A server endpoint entry from an OpenAPI spec's `servers` section."""

    # Base URL of the remote tool server.
    url: str
409
+
410
+
411
@dataclass
class EndpointInfo:
    """Metadata from an OpenAPI spec's `info` section."""

    # Human-readable title of the tool API.
    title: str
    description: str
    # Version string of the spec/tool, e.g. "1.0.0".
    version: str
erniebot-agent/erniebot_agent/tools/tool_manager.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ from typing import Dict, List, final
17
+
18
+ from erniebot_agent.tools.base import Tool
19
+
20
+
21
@final
class ToolManager(object):
    """A `ToolManager` instance manages tools for an agent.

    This implementation is based on `ToolsManager` in
    https://github.com/deepset-ai/haystack/blob/main/haystack/agents/base.py
    """

    def __init__(self, tools: List[Tool]) -> None:
        super().__init__()
        # Registry keyed by tool name; insertion order is preserved.
        self._tools: Dict[str, Tool] = {}
        for item in tools:
            self.add_tool(item)

    def __getitem__(self, tool_name: str) -> Tool:
        return self.get_tool(tool_name)

    def add_tool(self, tool: Tool) -> None:
        """Register *tool* under its own name; duplicate names are rejected."""
        name = tool.tool_name
        if name in self._tools:
            raise RuntimeError(f"Name {repr(name)} is already registered.")
        self._tools[name] = tool

    def remove_tool(self, tool: Tool) -> None:
        """Unregister *tool*; the exact same object must be registered."""
        name = tool.tool_name
        if name not in self._tools:
            raise RuntimeError(f"Name {repr(name)} is not registered.")
        # Identity check guards against removing a different tool that
        # happens to share the name.
        if self._tools[name] is not tool:
            raise RuntimeError(f"The tool with the registered name {repr(name)} is not the given tool.")
        del self._tools[name]

    def get_tool(self, tool_name: str) -> Tool:
        """Look up a registered tool by name."""
        if tool_name not in self._tools:
            raise RuntimeError(f"Name {repr(tool_name)} is not registered.")
        return self._tools[tool_name]

    def get_tools(self) -> List[Tool]:
        """Return all registered tools in registration order."""
        return [*self._tools.values()]

    def get_tool_names(self) -> str:
        """Return a comma-separated list of registered tool names."""
        return ", ".join(self._tools)

    def get_tool_names_with_descriptions(self) -> str:
        """Return one "name:schema-json" line per registered tool."""
        entries = [
            f"{name}:{json.dumps(registered.function_call_schema())}"
            for name, registered in self._tools.items()
        ]
        return "\n".join(entries)

    def get_tool_schemas(self):
        """Return the function_call schema of every registered tool."""
        return [registered.function_call_schema() for registered in self._tools.values()]
erniebot-agent/erniebot_agent/utils/__init__.py ADDED
File without changes