百度文心一言的例子
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- erniebot-agent/README.md +44 -0
- erniebot-agent/erniebot_agent/__init__.py +19 -0
- erniebot-agent/erniebot_agent/agents/__init__.py +13 -0
- erniebot-agent/erniebot_agent/agents/base.py +279 -0
- erniebot-agent/erniebot_agent/agents/callback/__init__.py +13 -0
- erniebot-agent/erniebot_agent/agents/callback/callback_manager.py +94 -0
- erniebot-agent/erniebot_agent/agents/callback/default.py +22 -0
- erniebot-agent/erniebot_agent/agents/callback/event.py +26 -0
- erniebot-agent/erniebot_agent/agents/callback/handlers/__init__.py +13 -0
- erniebot-agent/erniebot_agent/agents/callback/handlers/base.py +55 -0
- erniebot-agent/erniebot_agent/agents/callback/handlers/logging_handler.py +107 -0
- erniebot-agent/erniebot_agent/agents/functional_agent.py +148 -0
- erniebot-agent/erniebot_agent/agents/schema.py +93 -0
- erniebot-agent/erniebot_agent/chat_models/__init__.py +17 -0
- erniebot-agent/erniebot_agent/chat_models/base.py +60 -0
- erniebot-agent/erniebot_agent/chat_models/erniebot.py +135 -0
- erniebot-agent/erniebot_agent/extensions/langchain/chat_models/__init__.py +1 -0
- erniebot-agent/erniebot_agent/extensions/langchain/chat_models/erniebot.py +356 -0
- erniebot-agent/erniebot_agent/extensions/langchain/embeddings/__init__.py +1 -0
- erniebot-agent/erniebot_agent/extensions/langchain/embeddings/ernie.py +82 -0
- erniebot-agent/erniebot_agent/extensions/langchain/llms/__init__.py +1 -0
- erniebot-agent/erniebot_agent/extensions/langchain/llms/erniebot.py +239 -0
- erniebot-agent/erniebot_agent/file_io/__init__.py +13 -0
- erniebot-agent/erniebot_agent/file_io/base.py +46 -0
- erniebot-agent/erniebot_agent/file_io/file_manager.py +138 -0
- erniebot-agent/erniebot_agent/file_io/file_registry.py +55 -0
- erniebot-agent/erniebot_agent/file_io/local_file.py +55 -0
- erniebot-agent/erniebot_agent/file_io/protocol.py +57 -0
- erniebot-agent/erniebot_agent/file_io/remote_file.py +153 -0
- erniebot-agent/erniebot_agent/memory/__init__.py +18 -0
- erniebot-agent/erniebot_agent/memory/base.py +99 -0
- erniebot-agent/erniebot_agent/memory/limit_token_memory.py +59 -0
- erniebot-agent/erniebot_agent/memory/sliding_window_memory.py +41 -0
- erniebot-agent/erniebot_agent/memory/whole_memory.py +19 -0
- erniebot-agent/erniebot_agent/messages.py +124 -0
- erniebot-agent/erniebot_agent/prompt/__init__.py +16 -0
- erniebot-agent/erniebot_agent/prompt/base.py +28 -0
- erniebot-agent/erniebot_agent/prompt/prompt_template.py +80 -0
- erniebot-agent/erniebot_agent/retrieval/__init__.py +0 -0
- erniebot-agent/erniebot_agent/retrieval/baizhong_search.py +296 -0
- erniebot-agent/erniebot_agent/retrieval/document.py +123 -0
- erniebot-agent/erniebot_agent/tools/__init__.py +15 -0
- erniebot-agent/erniebot_agent/tools/baizhong_tool.py +65 -0
- erniebot-agent/erniebot_agent/tools/base.py +428 -0
- erniebot-agent/erniebot_agent/tools/calculator_tool.py +66 -0
- erniebot-agent/erniebot_agent/tools/current_time_tool.py +47 -0
- erniebot-agent/erniebot_agent/tools/image_generation_tool.py +117 -0
- erniebot-agent/erniebot_agent/tools/schema.py +415 -0
- erniebot-agent/erniebot_agent/tools/tool_manager.py +69 -0
- erniebot-agent/erniebot_agent/utils/__init__.py +0 -0
erniebot-agent/README.md
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
|
3 |
+
<h1>ERNIE Bot Agent</h1>
|
4 |
+
|
5 |
+
ERNIE Bot Agent 可以快速开发智能体。
|
6 |
+
|
7 |
+
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
|
8 |
+
![python version](https://img.shields.io/badge/python-3.8+-orange.svg)
|
9 |
+
![support os](https://img.shields.io/badge/os-linux%2C%20win%2C%20mac-yellow.svg)
|
10 |
+
|
11 |
+
</div>
|
12 |
+
|
13 |
+
`ERNIE Bot Agent` 旨在为开发者提供快速搭建大模型Agent和应用的框架。该项目还在积极研发中,敬请期待我们后续的正式发版。
|
14 |
+
|
15 |
+
## 主要功能
|
16 |
+
|
17 |
+
### 大模型 Agent 框架
|
18 |
+
|
19 |
+
`ERNIE Bot Agent` 将结合飞桨星河AI Studio社区,为开发者提供一站式的大模型Agent和应用搭建框架和平台。该项目还在积极研发中,敬请期待我们后续的正式发版。
|
20 |
+
|
21 |
+
### 文心 LangChain 插件
|
22 |
+
|
23 |
+
为了让大家更加高效、便捷地结合文心大模型与LangChain进行开发,`ERNIE Bot Agent`对`LangChain`框架进行了功能扩展,提供了基于文心大模型的大语言模型(LLM)组件、聊天模型(ChatModel)组件以及文本嵌入模型(Text Embedding Model)组件。详情请参见[使用范例Notebook](https://github.com/PaddlePaddle/ERNIE-Bot-SDK/blob/develop/erniebot-agent/examples/cookbook/how_to_use_langchain_extension.ipynb)。
|
24 |
+
|
25 |
+
|
26 |
+
## 快速安装
|
27 |
+
|
28 |
+
建议您可以使用pip快速安装 ERNIE Bot Agent 的最新稳定版。
|
29 |
+
|
30 |
+
```shell
|
31 |
+
pip install --upgrade erniebot-agent
|
32 |
+
```
|
33 |
+
|
34 |
+
如需使用develop版本,可以下载源码后执行如下命令安装
|
35 |
+
|
36 |
+
```shell
|
37 |
+
git clone https://github.com/PaddlePaddle/ERNIE-Bot-SDK.git
|
38 |
+
cd ERNIE-Bot-SDK/erniebot-agent
|
39 |
+
pip install .
|
40 |
+
```
|
41 |
+
|
42 |
+
## License
|
43 |
+
|
44 |
+
ERNIE Bot Agent遵循Apache-2.0开源协议。
|
erniebot-agent/erniebot_agent/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from erniebot_agent.utils.logging import logger, setup_logging
|
16 |
+
|
17 |
+
__all__ = ["logger"]
|
18 |
+
|
19 |
+
setup_logging()
|
erniebot-agent/erniebot_agent/agents/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
erniebot-agent/erniebot_agent/agents/base.py
ADDED
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import abc
|
16 |
+
import inspect
|
17 |
+
import json
|
18 |
+
from typing import Any, Dict, List, Literal, Optional, Union
|
19 |
+
|
20 |
+
from erniebot_agent.agents.callback.callback_manager import CallbackManager
|
21 |
+
from erniebot_agent.agents.callback.default import get_default_callbacks
|
22 |
+
from erniebot_agent.agents.callback.handlers.base import CallbackHandler
|
23 |
+
from erniebot_agent.agents.schema import (
|
24 |
+
AgentFile,
|
25 |
+
AgentResponse,
|
26 |
+
LLMResponse,
|
27 |
+
ToolResponse,
|
28 |
+
)
|
29 |
+
from erniebot_agent.chat_models.base import ChatModel
|
30 |
+
from erniebot_agent.file_io.file_manager import FileManager
|
31 |
+
from erniebot_agent.file_io.protocol import is_local_file_id, is_remote_file_id
|
32 |
+
from erniebot_agent.memory.base import Memory
|
33 |
+
from erniebot_agent.messages import Message, SystemMessage
|
34 |
+
from erniebot_agent.tools.base import Tool
|
35 |
+
from erniebot_agent.tools.tool_manager import ToolManager
|
36 |
+
from erniebot_agent.utils.logging import logger
|
37 |
+
|
38 |
+
|
39 |
+
class BaseAgent(metaclass=abc.ABCMeta):
|
40 |
+
llm: ChatModel
|
41 |
+
memory: Memory
|
42 |
+
|
43 |
+
@abc.abstractmethod
|
44 |
+
async def async_run(self, prompt: str) -> AgentResponse:
|
45 |
+
raise NotImplementedError
|
46 |
+
|
47 |
+
|
48 |
+
class Agent(BaseAgent):
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
llm: ChatModel,
|
52 |
+
tools: Union[ToolManager, List[Tool]],
|
53 |
+
memory: Memory,
|
54 |
+
system_message: Optional[SystemMessage] = None,
|
55 |
+
*,
|
56 |
+
callbacks: Optional[Union[CallbackManager, List[CallbackHandler]]] = None,
|
57 |
+
file_manager: Optional[FileManager] = None,
|
58 |
+
) -> None:
|
59 |
+
super().__init__()
|
60 |
+
self.llm = llm
|
61 |
+
self.memory = memory
|
62 |
+
# 1. Get system message exist in memory
|
63 |
+
# OR 2. overwrite by the system_message paased in the Agent.
|
64 |
+
if system_message:
|
65 |
+
self.system_message = system_message
|
66 |
+
else:
|
67 |
+
self.system_message = memory.get_system_message()
|
68 |
+
if isinstance(tools, ToolManager):
|
69 |
+
self._tool_manager = tools
|
70 |
+
else:
|
71 |
+
self._tool_manager = ToolManager(tools)
|
72 |
+
if callbacks is None:
|
73 |
+
callbacks = get_default_callbacks()
|
74 |
+
if isinstance(callbacks, CallbackManager):
|
75 |
+
self._callback_manager = callbacks
|
76 |
+
else:
|
77 |
+
self._callback_manager = CallbackManager(callbacks)
|
78 |
+
self.file_manager = file_manager
|
79 |
+
|
80 |
+
async def async_run(self, prompt: str) -> AgentResponse:
|
81 |
+
await self._callback_manager.on_run_start(agent=self, prompt=prompt)
|
82 |
+
agent_resp = await self._async_run(prompt)
|
83 |
+
await self._callback_manager.on_run_end(agent=self, response=agent_resp)
|
84 |
+
return agent_resp
|
85 |
+
|
86 |
+
def load_tool(self, tool: Tool) -> None:
|
87 |
+
self._tool_manager.add_tool(tool)
|
88 |
+
|
89 |
+
def unload_tool(self, tool: Tool) -> None:
|
90 |
+
self._tool_manager.remove_tool(tool)
|
91 |
+
|
92 |
+
def reset_memory(self) -> None:
|
93 |
+
self.memory.clear_chat_history()
|
94 |
+
|
95 |
+
def launch_gradio_demo(self, **launch_kwargs: Any):
|
96 |
+
# TODO: Unified optional dependencies management
|
97 |
+
try:
|
98 |
+
import gradio as gr
|
99 |
+
except ImportError:
|
100 |
+
raise ImportError(
|
101 |
+
"Could not import gradio, which is required for `launch_gradio_demo()`."
|
102 |
+
" Please run `pip install erniebot-agent[gradio]` to install the optional dependencies."
|
103 |
+
) from None
|
104 |
+
|
105 |
+
raw_messages = []
|
106 |
+
|
107 |
+
def _pre_chat(text, history):
|
108 |
+
history.append([text, None])
|
109 |
+
return history, gr.update(value="", interactive=False), gr.update(interactive=False)
|
110 |
+
|
111 |
+
async def _chat(history):
|
112 |
+
prompt = history[-1][0]
|
113 |
+
if len(prompt) == 0:
|
114 |
+
raise gr.Error("Prompt should not be empty.")
|
115 |
+
response = await self.async_run(prompt)
|
116 |
+
history[-1][1] = response.text
|
117 |
+
raw_messages.extend(response.chat_history)
|
118 |
+
return (
|
119 |
+
history,
|
120 |
+
_messages_to_dicts(raw_messages),
|
121 |
+
_messages_to_dicts(self.memory.get_messages()),
|
122 |
+
)
|
123 |
+
|
124 |
+
def _post_chat():
|
125 |
+
return gr.update(interactive=True), gr.update(interactive=True)
|
126 |
+
|
127 |
+
def _clear():
|
128 |
+
raw_messages.clear()
|
129 |
+
self.reset_memory()
|
130 |
+
return None, None, None, None
|
131 |
+
|
132 |
+
def _messages_to_dicts(messages):
|
133 |
+
return [message.to_dict() for message in messages]
|
134 |
+
|
135 |
+
with gr.Blocks(
|
136 |
+
title="ERNIE Bot Agent Demo", theme=gr.themes.Soft(spacing_size="sm", text_size="md")
|
137 |
+
) as demo:
|
138 |
+
with gr.Column():
|
139 |
+
chatbot = gr.Chatbot(
|
140 |
+
label="Chat history",
|
141 |
+
latex_delimiters=[
|
142 |
+
{"left": "$$", "right": "$$", "display": True},
|
143 |
+
{"left": "$", "right": "$", "display": False},
|
144 |
+
],
|
145 |
+
bubble_full_width=False,
|
146 |
+
)
|
147 |
+
prompt_textbox = gr.Textbox(label="Prompt", placeholder="Write a prompt here...")
|
148 |
+
with gr.Row():
|
149 |
+
submit_button = gr.Button("Submit")
|
150 |
+
clear_button = gr.Button("Clear")
|
151 |
+
with gr.Accordion("Tools", open=False):
|
152 |
+
attached_tools = self._tool_manager.get_tools()
|
153 |
+
tool_descriptions = [tool.function_call_schema() for tool in attached_tools]
|
154 |
+
gr.JSON(value=tool_descriptions)
|
155 |
+
with gr.Accordion("Raw messages", open=False):
|
156 |
+
all_messages_json = gr.JSON(label="All messages")
|
157 |
+
agent_memory_json = gr.JSON(label="Messges in memory")
|
158 |
+
prompt_textbox.submit(
|
159 |
+
_pre_chat,
|
160 |
+
inputs=[prompt_textbox, chatbot],
|
161 |
+
outputs=[chatbot, prompt_textbox, submit_button],
|
162 |
+
).then(
|
163 |
+
_chat,
|
164 |
+
inputs=[chatbot],
|
165 |
+
outputs=[
|
166 |
+
chatbot,
|
167 |
+
all_messages_json,
|
168 |
+
agent_memory_json,
|
169 |
+
],
|
170 |
+
).then(
|
171 |
+
_post_chat, outputs=[prompt_textbox, submit_button]
|
172 |
+
)
|
173 |
+
submit_button.click(
|
174 |
+
_pre_chat,
|
175 |
+
inputs=[prompt_textbox, chatbot],
|
176 |
+
outputs=[chatbot, prompt_textbox, submit_button],
|
177 |
+
).then(
|
178 |
+
_chat,
|
179 |
+
inputs=[chatbot],
|
180 |
+
outputs=[
|
181 |
+
chatbot,
|
182 |
+
all_messages_json,
|
183 |
+
agent_memory_json,
|
184 |
+
],
|
185 |
+
).then(
|
186 |
+
_post_chat, outputs=[prompt_textbox, submit_button]
|
187 |
+
)
|
188 |
+
clear_button.click(
|
189 |
+
_clear,
|
190 |
+
outputs=[
|
191 |
+
chatbot,
|
192 |
+
prompt_textbox,
|
193 |
+
all_messages_json,
|
194 |
+
agent_memory_json,
|
195 |
+
],
|
196 |
+
)
|
197 |
+
|
198 |
+
demo.launch(**launch_kwargs)
|
199 |
+
|
200 |
+
@abc.abstractmethod
|
201 |
+
async def _async_run(self, prompt: str) -> AgentResponse:
|
202 |
+
raise NotImplementedError
|
203 |
+
|
204 |
+
async def _async_run_tool(self, tool_name: str, tool_args: str) -> ToolResponse:
|
205 |
+
tool = self._tool_manager.get_tool(tool_name)
|
206 |
+
await self._callback_manager.on_tool_start(agent=self, tool=tool, input_args=tool_args)
|
207 |
+
try:
|
208 |
+
tool_resp = await self._async_run_tool_without_hooks(tool, tool_args)
|
209 |
+
except (Exception, KeyboardInterrupt) as e:
|
210 |
+
await self._callback_manager.on_tool_error(agent=self, tool=tool, error=e)
|
211 |
+
raise
|
212 |
+
await self._callback_manager.on_tool_end(agent=self, tool=tool, response=tool_resp)
|
213 |
+
return tool_resp
|
214 |
+
|
215 |
+
async def _async_run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
|
216 |
+
await self._callback_manager.on_llm_start(agent=self, llm=self.llm, messages=messages)
|
217 |
+
try:
|
218 |
+
llm_resp = await self._async_run_llm_without_hooks(messages, **opts)
|
219 |
+
except (Exception, KeyboardInterrupt) as e:
|
220 |
+
await self._callback_manager.on_llm_error(agent=self, llm=self.llm, error=e)
|
221 |
+
raise
|
222 |
+
await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp)
|
223 |
+
return llm_resp
|
224 |
+
|
225 |
+
async def _async_run_tool_without_hooks(self, tool: Tool, tool_args: str) -> ToolResponse:
|
226 |
+
bnd_args = self._parse_tool_args(tool, tool_args)
|
227 |
+
# XXX: Sniffing is less efficient and probably unnecessary.
|
228 |
+
# Can we make a protocol to statically recognize file inputs and outputs
|
229 |
+
# or can we have the tools introspect about this?
|
230 |
+
input_files = await self._sniff_and_extract_files_from_args(bnd_args.arguments, tool, "input")
|
231 |
+
tool_ret = await tool(*bnd_args.args, **bnd_args.kwargs)
|
232 |
+
output_files = await self._sniff_and_extract_files_from_args(tool_ret, tool, "output")
|
233 |
+
tool_ret_json = json.dumps(tool_ret, ensure_ascii=False)
|
234 |
+
return ToolResponse(json=tool_ret_json, files=input_files + output_files)
|
235 |
+
|
236 |
+
async def _async_run_llm_without_hooks(
|
237 |
+
self, messages: List[Message], functions=None, **opts: Any
|
238 |
+
) -> LLMResponse:
|
239 |
+
llm_ret = await self.llm.async_chat(messages, functions=functions, stream=False, **opts)
|
240 |
+
return LLMResponse(message=llm_ret)
|
241 |
+
|
242 |
+
def _parse_tool_args(self, tool: Tool, tool_args: str) -> inspect.BoundArguments:
|
243 |
+
args_dict = json.loads(tool_args)
|
244 |
+
if not isinstance(args_dict, dict):
|
245 |
+
raise ValueError("`tool_args` cannot be interpreted as a dict.")
|
246 |
+
# TODO: Check types
|
247 |
+
sig = inspect.signature(tool.__call__)
|
248 |
+
bnd_args = sig.bind(**args_dict)
|
249 |
+
bnd_args.apply_defaults()
|
250 |
+
return bnd_args
|
251 |
+
|
252 |
+
async def _sniff_and_extract_files_from_args(
|
253 |
+
self, args: Dict[str, Any], tool: Tool, file_type: Literal["input", "output"]
|
254 |
+
) -> List[AgentFile]:
|
255 |
+
agent_files: List[AgentFile] = []
|
256 |
+
for val in args.values():
|
257 |
+
if isinstance(val, str):
|
258 |
+
if is_local_file_id(val):
|
259 |
+
if self.file_manager is None:
|
260 |
+
logger.warning(
|
261 |
+
f"A file is used by {repr(tool)}, but the agent has no file manager to fetch it."
|
262 |
+
)
|
263 |
+
continue
|
264 |
+
file = self.file_manager.look_up_file_by_id(val)
|
265 |
+
if file is None:
|
266 |
+
raise RuntimeError(f"Unregistered ID {repr(val)} is used by {repr(tool)}.")
|
267 |
+
elif is_remote_file_id(val):
|
268 |
+
if self.file_manager is None:
|
269 |
+
logger.warning(
|
270 |
+
f"A file is used by {repr(tool)}, but the agent has no file manager to fetch it."
|
271 |
+
)
|
272 |
+
continue
|
273 |
+
file = self.file_manager.look_up_file_by_id(val)
|
274 |
+
if file is None:
|
275 |
+
file = await self.file_manager.retrieve_remote_file_by_id(val)
|
276 |
+
else:
|
277 |
+
continue
|
278 |
+
agent_files.append(AgentFile(file=file, type=file_type, used_by=tool.tool_name))
|
279 |
+
return agent_files
|
erniebot-agent/erniebot_agent/agents/callback/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
erniebot-agent/erniebot_agent/agents/callback/callback_manager.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import annotations
|
16 |
+
|
17 |
+
import inspect
|
18 |
+
from typing import TYPE_CHECKING, Any, List, Union, final
|
19 |
+
|
20 |
+
from erniebot_agent.agents.callback.event import EventType
|
21 |
+
from erniebot_agent.agents.callback.handlers.base import CallbackHandler
|
22 |
+
from erniebot_agent.agents.schema import AgentResponse, LLMResponse, ToolResponse
|
23 |
+
from erniebot_agent.chat_models.base import ChatModel
|
24 |
+
from erniebot_agent.messages import Message
|
25 |
+
from erniebot_agent.tools.base import Tool
|
26 |
+
|
27 |
+
if TYPE_CHECKING:
|
28 |
+
from erniebot_agent.agents.base import Agent
|
29 |
+
|
30 |
+
|
31 |
+
@final
|
32 |
+
class CallbackManager(object):
|
33 |
+
def __init__(self, handlers: List[CallbackHandler]):
|
34 |
+
super().__init__()
|
35 |
+
self._handlers = handlers
|
36 |
+
|
37 |
+
@property
|
38 |
+
def handlers(self) -> List[CallbackHandler]:
|
39 |
+
return self._handlers
|
40 |
+
|
41 |
+
def add_handler(self, handler: CallbackHandler):
|
42 |
+
if handler in self._handlers:
|
43 |
+
raise RuntimeError(f"The callback handler {handler} is already registered.")
|
44 |
+
self._handlers.append(handler)
|
45 |
+
|
46 |
+
def remove_handler(self, handler):
|
47 |
+
try:
|
48 |
+
self._handlers.remove(handler)
|
49 |
+
except ValueError as e:
|
50 |
+
raise RuntimeError(f"The callback handler {handler} is not registered.") from e
|
51 |
+
|
52 |
+
def set_handlers(self, handlers: List[CallbackHandler]):
|
53 |
+
self._handlers = []
|
54 |
+
for handler in handlers:
|
55 |
+
self.add_handler(handler)
|
56 |
+
|
57 |
+
def remove_all_handlers(self):
|
58 |
+
self._handlers = []
|
59 |
+
|
60 |
+
async def handle_event(self, event_type: EventType, *args: Any, **kwargs: Any) -> None:
|
61 |
+
callback_name = "on_" + event_type.value
|
62 |
+
for handler in self._handlers:
|
63 |
+
callback = getattr(handler, callback_name, None)
|
64 |
+
if not inspect.iscoroutinefunction(callback):
|
65 |
+
raise TypeError("Callback must be a coroutine function.")
|
66 |
+
await callback(*args, **kwargs)
|
67 |
+
|
68 |
+
async def on_run_start(self, agent: Agent, prompt: str) -> None:
|
69 |
+
await self.handle_event(EventType.RUN_START, agent=agent, prompt=prompt)
|
70 |
+
|
71 |
+
async def on_llm_start(self, agent: Agent, llm: ChatModel, messages: List[Message]) -> None:
|
72 |
+
await self.handle_event(EventType.LLM_START, agent=agent, llm=llm, messages=messages)
|
73 |
+
|
74 |
+
async def on_llm_end(self, agent: Agent, llm: ChatModel, response: LLMResponse) -> None:
|
75 |
+
await self.handle_event(EventType.LLM_END, agent=agent, llm=llm, response=response)
|
76 |
+
|
77 |
+
async def on_llm_error(
|
78 |
+
self, agent: Agent, llm: ChatModel, error: Union[Exception, KeyboardInterrupt]
|
79 |
+
) -> None:
|
80 |
+
await self.handle_event(EventType.LLM_ERROR, agent=agent, llm=llm, error=error)
|
81 |
+
|
82 |
+
async def on_tool_start(self, agent: Agent, tool: Tool, input_args: str) -> None:
|
83 |
+
await self.handle_event(EventType.TOOL_START, agent=agent, tool=tool, input_args=input_args)
|
84 |
+
|
85 |
+
async def on_tool_end(self, agent: Agent, tool: Tool, response: ToolResponse) -> None:
|
86 |
+
await self.handle_event(EventType.TOOL_END, agent=agent, tool=tool, response=response)
|
87 |
+
|
88 |
+
async def on_tool_error(
|
89 |
+
self, agent: Agent, tool: Tool, error: Union[Exception, KeyboardInterrupt]
|
90 |
+
) -> None:
|
91 |
+
await self.handle_event(EventType.TOOL_ERROR, agent=agent, tool=tool, error=error)
|
92 |
+
|
93 |
+
async def on_run_end(self, agent: Agent, response: AgentResponse) -> None:
|
94 |
+
await self.handle_event(EventType.RUN_END, agent=agent, response=response)
|
erniebot-agent/erniebot_agent/agents/callback/default.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import List
|
16 |
+
|
17 |
+
from erniebot_agent.agents.callback.handlers.base import CallbackHandler
|
18 |
+
from erniebot_agent.agents.callback.handlers.logging_handler import LoggingHandler
|
19 |
+
|
20 |
+
|
21 |
+
def get_default_callbacks() -> List[CallbackHandler]:
|
22 |
+
return [LoggingHandler()]
|
erniebot-agent/erniebot_agent/agents/callback/event.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import enum
|
16 |
+
|
17 |
+
|
18 |
+
class EventType(enum.Enum):
|
19 |
+
RUN_START = "run_start"
|
20 |
+
LLM_START = "llm_start"
|
21 |
+
LLM_END = "llm_end"
|
22 |
+
LLM_ERROR = "llm_error"
|
23 |
+
TOOL_START = "tool_start"
|
24 |
+
TOOL_END = "tool_end"
|
25 |
+
TOOL_ERROR = "tool_error"
|
26 |
+
RUN_END = "run_end"
|
erniebot-agent/erniebot_agent/agents/callback/handlers/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
erniebot-agent/erniebot_agent/agents/callback/handlers/base.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import annotations
|
16 |
+
|
17 |
+
from typing import TYPE_CHECKING, List, Union
|
18 |
+
|
19 |
+
from erniebot_agent.agents.schema import AgentResponse, LLMResponse, ToolResponse
|
20 |
+
from erniebot_agent.chat_models.base import ChatModel
|
21 |
+
from erniebot_agent.messages import Message
|
22 |
+
from erniebot_agent.tools.base import Tool
|
23 |
+
|
24 |
+
if TYPE_CHECKING:
|
25 |
+
from erniebot_agent.agents.base import Agent
|
26 |
+
|
27 |
+
|
28 |
+
class CallbackHandler(object):
|
29 |
+
async def on_run_start(self, agent: Agent, prompt: str) -> None:
|
30 |
+
""""""
|
31 |
+
|
32 |
+
async def on_llm_start(self, agent: Agent, llm: ChatModel, messages: List[Message]) -> None:
|
33 |
+
""""""
|
34 |
+
|
35 |
+
async def on_llm_end(self, agent: Agent, llm: ChatModel, response: LLMResponse) -> None:
|
36 |
+
""""""
|
37 |
+
|
38 |
+
async def on_llm_error(
|
39 |
+
self, agent: Agent, llm: ChatModel, error: Union[Exception, KeyboardInterrupt]
|
40 |
+
) -> None:
|
41 |
+
""""""
|
42 |
+
|
43 |
+
async def on_tool_start(self, agent: Agent, tool: Tool, input_args: str) -> None:
|
44 |
+
""""""
|
45 |
+
|
46 |
+
async def on_tool_end(self, agent: Agent, tool: Tool, response: ToolResponse) -> None:
|
47 |
+
""""""
|
48 |
+
|
49 |
+
async def on_tool_error(
|
50 |
+
self, agent: Agent, tool: Tool, error: Union[Exception, KeyboardInterrupt]
|
51 |
+
) -> None:
|
52 |
+
""""""
|
53 |
+
|
54 |
+
async def on_run_end(self, agent: Agent, response: AgentResponse) -> None:
|
55 |
+
""""""
|
erniebot-agent/erniebot_agent/agents/callback/handlers/logging_handler.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import annotations
|
16 |
+
|
17 |
+
import logging
|
18 |
+
from typing import TYPE_CHECKING, List, Optional, Union
|
19 |
+
|
20 |
+
from erniebot_agent.agents.callback.handlers.base import CallbackHandler
|
21 |
+
from erniebot_agent.agents.schema import AgentResponse, LLMResponse, ToolResponse
|
22 |
+
from erniebot_agent.chat_models.base import ChatModel
|
23 |
+
from erniebot_agent.messages import Message
|
24 |
+
from erniebot_agent.tools.base import Tool
|
25 |
+
from erniebot_agent.utils.json import to_pretty_json
|
26 |
+
from erniebot_agent.utils.logging import logger as default_logger
|
27 |
+
|
28 |
+
if TYPE_CHECKING:
|
29 |
+
from erniebot_agent.agents.base import Agent
|
30 |
+
|
31 |
+
|
32 |
+
class LoggingHandler(CallbackHandler):
    """Callback handler that logs agent, LLM, and tool lifecycle events.

    Messages are written to the given logger (or the package default logger)
    with a ``[subject][state]`` prefix, e.g. ``[LLM][Start] ...``.
    """

    logger: logging.Logger

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        """Initializes the handler.

        Args:
            logger: The logger to write to. Defaults to the package-level
                default logger.
        """
        super().__init__()
        if logger is None:
            self.logger = default_logger
        else:
            self.logger = logger

    async def on_run_start(self, agent: Agent, prompt: str) -> None:
        """Logs the start of an agent run."""
        self.agent_info(
            "%s is about to start running with input: %s\n",
            agent.__class__.__name__,
            prompt,
            subject="Run",
            state="Start",
        )

    async def on_llm_start(self, agent: Agent, llm: ChatModel, messages: List[Message]) -> None:
        """Logs the start of an LLM call."""
        # TODO: Prettier messages
        self.agent_info(
            "%s is about to start running with input:\n%s\n",
            llm.__class__.__name__,
            messages,
            subject="LLM",
            state="Start",
        )

    async def on_llm_end(self, agent: Agent, llm: ChatModel, response: LLMResponse) -> None:
        """Logs the LLM output message."""
        self.agent_info(
            "%s finished running with output: %s\n",
            llm.__class__.__name__,
            response.message,
            subject="LLM",
            state="End",
        )

    async def on_llm_error(
        self, agent: Agent, llm: ChatModel, error: Union[Exception, KeyboardInterrupt]
    ) -> None:
        """Logs an LLM error.

        Fix: this hook was previously a silent ``pass``, which dropped errors
        in a logging handler and left ``agent_error`` unused.
        """
        self.agent_error(error, subject="LLM")

    async def on_tool_start(self, agent: Agent, tool: Tool, input_args: str) -> None:
        """Logs the start of a tool invocation with pretty-printed arguments."""
        self.agent_info(
            "%s is about to start running with input:\n%s\n",
            tool.__class__.__name__,
            to_pretty_json(input_args, from_json=True),
            subject="Tool",
            state="Start",
        )

    async def on_tool_end(self, agent: Agent, tool: Tool, response: ToolResponse) -> None:
        """Logs the tool output with pretty-printed JSON."""
        self.agent_info(
            "%s finished running with output:\n%s\n",
            tool.__class__.__name__,
            to_pretty_json(response.json, from_json=True),
            subject="Tool",
            state="End",
        )

    async def on_tool_error(
        self, agent: Agent, tool: Tool, error: Union[Exception, KeyboardInterrupt]
    ) -> None:
        """Logs a tool error.

        Fix: this hook was previously a silent ``pass``; see ``on_llm_error``.
        """
        self.agent_error(error, subject="Tool")

    async def on_run_end(self, agent: Agent, response: AgentResponse) -> None:
        """Logs the end of an agent run."""
        self.agent_info("%s finished running.\n", agent.__class__.__name__, subject="Run", state="End")

    def agent_info(self, msg: str, *args, subject, state, **kwargs) -> None:
        """Logs an info-level event tagged with its subject and state.

        `args`/`kwargs` are forwarded to ``Logger.info`` so %-style arguments
        stay lazily formatted.
        """
        msg = f"[{subject}][{state}] {msg}"
        self.logger.info(msg, *args, **kwargs)

    def agent_error(self, error: Union[Exception, KeyboardInterrupt], *args, subject, **kwargs) -> None:
        """Logs an error-level event tagged with its subject."""
        error_msg = f"[{subject}][ERROR] {error}"
        self.logger.error(error_msg, *args, **kwargs)
|
erniebot-agent/erniebot_agent/agents/functional_agent.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import List, Optional, Union
|
16 |
+
|
17 |
+
from erniebot_agent.agents.base import Agent, ToolManager
|
18 |
+
from erniebot_agent.agents.callback.callback_manager import CallbackManager
|
19 |
+
from erniebot_agent.agents.callback.handlers.base import CallbackHandler
|
20 |
+
from erniebot_agent.agents.schema import AgentAction, AgentFile, AgentResponse
|
21 |
+
from erniebot_agent.chat_models.base import ChatModel
|
22 |
+
from erniebot_agent.file_io.file_manager import FileManager
|
23 |
+
from erniebot_agent.memory.base import Memory
|
24 |
+
from erniebot_agent.messages import (
|
25 |
+
FunctionMessage,
|
26 |
+
HumanMessage,
|
27 |
+
Message,
|
28 |
+
SystemMessage,
|
29 |
+
)
|
30 |
+
from erniebot_agent.tools.base import Tool
|
31 |
+
|
32 |
+
_MAX_STEPS = 5
|
33 |
+
|
34 |
+
|
35 |
+
class FunctionalAgent(Agent):
    """An agent that drives an LLM with function calling in a plan-act loop.

    Each step, the LLM is asked to plan the next action; if it requests a
    function call, the corresponding tool is run and its result is fed back
    as a function message. The loop ends when the LLM answers without a
    function call, or when the step budget is exhausted.
    """

    def __init__(
        self,
        llm: ChatModel,
        tools: Union[ToolManager, List[Tool]],
        memory: Memory,
        system_message: Optional[SystemMessage] = None,
        *,
        callbacks: Optional[Union[CallbackManager, List[CallbackHandler]]] = None,
        file_manager: Optional[FileManager] = None,
        max_steps: Optional[int] = None,
    ) -> None:
        """Initializes the agent.

        Args:
            llm: The chat model used for planning.
            tools: The tools the agent may call.
            memory: Conversation memory shared across runs.
            system_message: Optional system message passed to the LLM.
            callbacks: Callback manager or list of handlers for agent events.
            file_manager: Optional file manager for tool input/output files.
            max_steps: Maximum number of tool steps per run; must be positive.
                Defaults to the module-level ``_MAX_STEPS``.

        Raises:
            ValueError: If ``max_steps`` is not positive.
        """
        super().__init__(
            llm=llm,
            tools=tools,
            memory=memory,
            system_message=system_message,
            callbacks=callbacks,
            file_manager=file_manager,
        )
        if max_steps is not None:
            if max_steps <= 0:
                raise ValueError("Invalid `max_steps` value")
            self.max_steps = max_steps
        else:
            self.max_steps = _MAX_STEPS

    async def _async_run(self, prompt: str) -> AgentResponse:
        """Runs the plan-act loop for a single user prompt."""
        chat_history: List[Message] = []
        actions_taken: List[AgentAction] = []
        files_involved: List[AgentFile] = []
        ask = HumanMessage(content=prompt)

        num_steps_taken = 0
        next_step_input: Message = ask
        while num_steps_taken < self.max_steps:
            curr_step_output = await self._async_step(
                next_step_input, chat_history, actions_taken, files_involved
            )
            if curr_step_output is None:
                # The LLM answered without requesting a tool call: finished.
                response = self._create_finished_response(chat_history, actions_taken, files_involved)
                # Only the user's question and the final answer are persisted
                # to memory; intermediate function messages are not.
                self.memory.add_message(chat_history[0])
                self.memory.add_message(chat_history[-1])
                return response
            num_steps_taken += 1
            next_step_input = curr_step_output
        # Step budget exhausted before the LLM produced a final answer.
        response = self._create_stopped_response(chat_history, actions_taken, files_involved)
        return response

    async def _async_step(
        self,
        step_input,
        chat_history: List[Message],
        actions: List[AgentAction],
        files: List[AgentFile],
    ) -> Optional[Message]:
        """Performs one plan-act step.

        Returns the function message to feed into the next step, or None when
        the LLM did not request a tool call (i.e. the run is finished).
        """
        maybe_action = await self._async_plan(step_input, chat_history)
        if isinstance(maybe_action, AgentAction):
            action: AgentAction = maybe_action
            tool_resp = await self._async_run_tool(tool_name=action.tool_name, tool_args=action.tool_args)
            actions.append(action)
            files.extend(tool_resp.files)
            return FunctionMessage(name=action.tool_name, content=tool_resp.json)
        else:
            return None

    async def _async_plan(
        self, input_message: Message, chat_history: List[Message]
    ) -> Optional[AgentAction]:
        """Asks the LLM for the next action.

        Appends both the input and the LLM's output to ``chat_history``
        (mutated in place). Returns an AgentAction when the LLM requested a
        function call, otherwise None.
        """
        chat_history.append(input_message)
        # The LLM sees persisted memory plus this run's accumulated history.
        messages = self.memory.get_messages() + chat_history
        llm_resp = await self._async_run_llm(
            messages=messages,
            functions=self._tool_manager.get_tool_schemas(),
            system=self.system_message.content if self.system_message is not None else None,
        )
        output_message = llm_resp.message
        chat_history.append(output_message)
        if output_message.function_call is not None:
            return AgentAction(
                tool_name=output_message.function_call["name"],  # type: ignore
                tool_args=output_message.function_call["arguments"],
            )
        else:
            return None

    def _create_finished_response(
        self,
        chat_history: List[Message],
        actions: List[AgentAction],
        files: List[AgentFile],
    ) -> AgentResponse:
        """Builds the response for a run that ended with a final LLM answer."""
        last_message = chat_history[-1]
        return AgentResponse(
            text=last_message.content,
            chat_history=chat_history,
            actions=actions,
            files=files,
            status="FINISHED",
        )

    def _create_stopped_response(
        self,
        chat_history: List[Message],
        actions: List[AgentAction],
        files: List[AgentFile],
    ) -> AgentResponse:
        """Builds the response for a run stopped by the step limit."""
        return AgentResponse(
            text="Agent run stopped early.",
            chat_history=chat_history,
            actions=actions,
            files=files,
            status="STOPPED",
        )
|
erniebot-agent/erniebot_agent/agents/schema.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
from erniebot_agent.file_io.base import File
|
19 |
+
from erniebot_agent.messages import AIMessage, Message
|
20 |
+
from typing_extensions import Literal
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass
class AgentAction(object):
    """An action for an agent to execute."""

    # Name of the tool to invoke.
    tool_name: str
    # JSON-encoded arguments for the tool call.
    tool_args: str
|
29 |
+
|
30 |
+
|
31 |
+
@dataclass
class AgentPlan(object):
    """A plan that contains a list of actions."""

    # Actions to execute, in order.
    actions: List[AgentAction]
|
36 |
+
|
37 |
+
|
38 |
+
@dataclass
class LLMResponse(object):
    """A response from an LLM."""

    # The assistant message produced by the model.
    message: AIMessage
|
43 |
+
|
44 |
+
|
45 |
+
@dataclass
class ToolResponse(object):
    """A response from a tool."""

    # JSON-encoded tool output.
    json: str
    # Files the tool read or produced during execution.
    files: List["AgentFile"]
|
51 |
+
|
52 |
+
|
53 |
+
@dataclass
class AgentResponse(object):
    """The final response from an agent."""

    text: str
    chat_history: List[Message]
    actions: List[AgentAction]
    files: List["AgentFile"]
    status: Union[Literal["FINISHED"], Literal["STOPPED"]]

    def get_last_output_file(self) -> Optional[File]:
        """Return the most recently produced output file, or None if there is none."""
        return next(
            (record.file for record in reversed(self.files) if record.type == "output"),
            None,
        )

    def get_output_files(self) -> List[File]:
        """Return every output file, in the order it was produced."""
        outputs: List[File] = []
        for record in self.files:
            if record.type == "output":
                outputs.append(record.file)
        return outputs

    def get_tool_input_output_files(self, tool_name: str) -> Tuple[List[File], List[File]]:
        """Return the (input files, output files) used by the named tool.

        Raises:
            RuntimeError: If a file record has a type other than
                "input" or "output".
        """
        input_files: List[File] = []
        output_files: List[File] = []
        for record in self.files:
            if record.used_by != tool_name:
                continue
            if record.type == "input":
                input_files.append(record.file)
            elif record.type == "output":
                output_files.append(record.file)
            else:
                raise RuntimeError("File type is neither input nor output.")
        return input_files, output_files
|
85 |
+
|
86 |
+
|
87 |
+
@dataclass
class AgentFile(object):
    """A file that is used by an agent."""

    # The underlying file object.
    file: File
    # Whether the file was consumed ("input") or produced ("output").
    type: Literal["input", "output"]
    # Name of the tool that used this file.
    used_by: str
|
erniebot-agent/erniebot_agent/chat_models/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .erniebot import ERNIEBot
|
16 |
+
|
17 |
+
__all__ = ["ERNIEBot"]
|
erniebot-agent/erniebot_agent/chat_models/base.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from abc import ABCMeta, abstractmethod
|
16 |
+
from typing import Any, AsyncIterator, List, Literal, Union, overload
|
17 |
+
|
18 |
+
from erniebot_agent.messages import AIMessage, AIMessageChunk, Message
|
19 |
+
|
20 |
+
|
21 |
+
class ChatModel(metaclass=ABCMeta):
    """The base class of chat-optimized LLM."""

    def __init__(self, model: str):
        # Name of the underlying model, e.g. "ernie-bot".
        self.model = model

    # Overloads: the static return type depends on the value of `stream`.
    @overload
    async def async_chat(
        self, messages: List[Message], *, stream: Literal[False] = ..., **kwargs: Any
    ) -> AIMessage:
        ...

    @overload
    async def async_chat(
        self, messages: List[Message], *, stream: Literal[True], **kwargs: Any
    ) -> AsyncIterator[AIMessageChunk]:
        ...

    @overload
    async def async_chat(
        self, messages: List[Message], *, stream: bool, **kwargs: Any
    ) -> Union[AIMessage, AsyncIterator[AIMessageChunk]]:
        ...

    @abstractmethod
    async def async_chat(
        self, messages: List[Message], *, stream: bool = False, **kwargs: Any
    ) -> Union[AIMessage, AsyncIterator[AIMessageChunk]]:
        """Asynchronously chats with the LLM.

        Args:
            messages (List[Message]): A list of messages.
            stream (bool): Whether to use streaming generation. Defaults to False.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            If stream is False, returns a single message.
            If stream is True, returns an asynchronous iterator of message chunks.
        """
        raise NotImplementedError
|
erniebot-agent/erniebot_agent/chat_models/erniebot.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import (
|
16 |
+
Any,
|
17 |
+
AsyncIterator,
|
18 |
+
Dict,
|
19 |
+
List,
|
20 |
+
Literal,
|
21 |
+
Optional,
|
22 |
+
Type,
|
23 |
+
TypeVar,
|
24 |
+
Union,
|
25 |
+
overload,
|
26 |
+
)
|
27 |
+
|
28 |
+
from erniebot_agent.chat_models.base import ChatModel
|
29 |
+
from erniebot_agent.messages import AIMessage, AIMessageChunk, FunctionCall, Message
|
30 |
+
|
31 |
+
import erniebot
|
32 |
+
from erniebot.response import EBResponse
|
33 |
+
|
34 |
+
_T = TypeVar("_T", AIMessage, AIMessageChunk)
|
35 |
+
|
36 |
+
|
37 |
+
class ERNIEBot(ChatModel):
    def __init__(
        self, model: str, api_type: Optional[str] = None, access_token: Optional[str] = None
    ) -> None:
        """Initializes an instance of the `ERNIEBot` class.

        Args:
            model (str): The model name. It should be "ernie-bot", "ernie-bot-turbo", "ernie-bot-8k", or
                "ernie-bot-4".
            api_type (Optional[str]): The API type for erniebot. It should be "aistudio" or "qianfan".
            access_token (Optional[str]): The access token for erniebot.
        """
        super().__init__(model=model)
        self.api_type = api_type
        self.access_token = access_token

    @overload
    async def async_chat(
        self,
        messages: List[Message],
        *,
        stream: Literal[False] = ...,
        functions: Optional[List[dict]] = ...,
        **kwargs: Any,
    ) -> AIMessage:
        ...

    @overload
    async def async_chat(
        self,
        messages: List[Message],
        *,
        stream: Literal[True],
        functions: Optional[List[dict]] = ...,
        **kwargs: Any,
    ) -> AsyncIterator[AIMessageChunk]:
        ...

    @overload
    async def async_chat(
        self, messages: List[Message], *, stream: bool, functions: Optional[List[dict]] = ..., **kwargs: Any
    ) -> Union[AIMessage, AsyncIterator[AIMessageChunk]]:
        ...

    async def async_chat(
        self,
        messages: List[Message],
        *,
        stream: bool = False,
        functions: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> Union[AIMessage, AsyncIterator[AIMessageChunk]]:
        """Asynchronously chats with the ERNIE Bot model.

        Args:
            messages (List[Message]): A list of messages.
            stream (bool): Whether to use streaming generation. Defaults to False.
            functions (Optional[List[dict]]): The function definitions to be used by the model.
                Defaults to None.
            **kwargs: Keyword arguments, such as `top_p`, `temperature`, `penalty_score`, and `system`.

        Returns:
            If `stream` is False, returns a single message.
            If `stream` is True, returns an asynchronous iterator of message chunks.
        """
        # Per-request auth config; only include values that were supplied.
        auth_config: Dict[str, Any] = {}
        if self.api_type is not None:
            auth_config["api_type"] = self.api_type
        if self.access_token is not None:
            auth_config["access_token"] = self.access_token
        request: Dict[str, Any] = {"model": self.model, "_config_": auth_config}

        # TODO: process system message
        request["messages"] = [m.to_dict() for m in messages]
        if functions is not None:
            request["functions"] = functions

        # Forward only the whitelisted generation options from kwargs.
        for option in ("top_p", "temperature", "penalty_score", "system"):
            if option in kwargs:
                request[option] = kwargs[option]

        # TODO: Improve this when erniebot typing issue is fixed.
        response: Any = await erniebot.ChatCompletion.acreate(stream=stream, **request)
        if isinstance(response, EBResponse):
            return self.convert_response_to_output(response, AIMessage)
        return (self.convert_response_to_output(resp, AIMessageChunk) async for resp in response)

    @staticmethod
    def convert_response_to_output(response: EBResponse, output_type: Type[_T]) -> _T:
        """Converts a raw erniebot response into a message of the given type.

        A response carrying a `function_call` attribute is converted into a
        message with empty content and a populated FunctionCall; otherwise
        the response text becomes the message content.
        """
        if not hasattr(response, "function_call"):
            return output_type(content=response.result, function_call=None, token_usage=response.usage)
        call = FunctionCall(
            name=response.function_call["name"],
            thoughts=response.function_call["thoughts"],
            arguments=response.function_call["arguments"],
        )
        return output_type(content="", function_call=call, token_usage=response.usage)
|
erniebot-agent/erniebot_agent/extensions/langchain/chat_models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .erniebot import ErnieBotChat
|
erniebot-agent/erniebot_agent/extensions/langchain/chat_models/erniebot.py
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import (
|
4 |
+
Any,
|
5 |
+
AsyncIterator,
|
6 |
+
Callable,
|
7 |
+
Dict,
|
8 |
+
Iterator,
|
9 |
+
List,
|
10 |
+
Mapping,
|
11 |
+
Optional,
|
12 |
+
Type,
|
13 |
+
Union,
|
14 |
+
)
|
15 |
+
|
16 |
+
from langchain.callbacks.manager import (
|
17 |
+
AsyncCallbackManagerForLLMRun,
|
18 |
+
CallbackManagerForLLMRun,
|
19 |
+
)
|
20 |
+
from langchain.chat_models.base import BaseChatModel
|
21 |
+
from langchain.llms.base import create_base_retry_decorator
|
22 |
+
from langchain.pydantic_v1 import Field, root_validator
|
23 |
+
from langchain.schema import ChatGeneration, ChatResult
|
24 |
+
from langchain.schema.messages import (
|
25 |
+
AIMessage,
|
26 |
+
AIMessageChunk,
|
27 |
+
BaseMessage,
|
28 |
+
ChatMessage,
|
29 |
+
FunctionMessage,
|
30 |
+
HumanMessage,
|
31 |
+
SystemMessage,
|
32 |
+
)
|
33 |
+
from langchain.schema.output import ChatGenerationChunk
|
34 |
+
from langchain.utils import get_from_dict_or_env
|
35 |
+
|
36 |
+
_MessageDict = Dict[str, Any]
|
37 |
+
|
38 |
+
|
39 |
+
class ErnieBotChat(BaseChatModel):
|
40 |
+
"""ERNIE Bot Chat large language models API.
|
41 |
+
|
42 |
+
To use, you should have the ``erniebot`` python package installed, and the
|
43 |
+
environment variable ``EB_ACCESS_TOKEN`` set with your AI Studio access token.
|
44 |
+
|
45 |
+
Example:
|
46 |
+
.. code-block:: python
|
47 |
+
from erniebot_agent.extensions.langchain.chat_models import ErnieBotChat
|
48 |
+
erniebot_chat = ErnieBotChat(model="ernie-bot")
|
49 |
+
"""
|
50 |
+
|
51 |
+
client: Any = None
|
52 |
+
max_retries: int = 6
|
53 |
+
"""Maximum number of retries to make when generating."""
|
54 |
+
aistudio_access_token: Optional[str] = None
|
55 |
+
"""AI Studio access token."""
|
56 |
+
streaming: Optional[bool] = False
|
57 |
+
"""Whether to stream the results or not."""
|
58 |
+
model: str = "ernie-bot"
|
59 |
+
"""Model to use."""
|
60 |
+
top_p: Optional[float] = 0.8
|
61 |
+
"""Parameter of nucleus sampling that affects the diversity of generated content."""
|
62 |
+
temperature: Optional[float] = 0.95
|
63 |
+
"""Sampling temperature to use."""
|
64 |
+
penalty_score: Optional[float] = 1
|
65 |
+
"""Penalty assigned to tokens that have been generated."""
|
66 |
+
request_timeout: Optional[int] = 60
|
67 |
+
"""How many seconds to wait for the server to send data before giving up."""
|
68 |
+
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
69 |
+
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
70 |
+
|
71 |
+
ernie_client_id: Optional[str] = None
|
72 |
+
ernie_client_secret: Optional[str] = None
|
73 |
+
"""For raising deprecation warnings."""
|
74 |
+
|
75 |
+
@property
|
76 |
+
def _default_params(self) -> Dict[str, Any]:
|
77 |
+
"""Get the default parameters for calling ERNIE Bot API."""
|
78 |
+
normal_params = {
|
79 |
+
"model": self.model,
|
80 |
+
"top_p": self.top_p,
|
81 |
+
"temperature": self.temperature,
|
82 |
+
"penalty_score": self.penalty_score,
|
83 |
+
"request_timeout": self.request_timeout,
|
84 |
+
}
|
85 |
+
return {**normal_params, **self.model_kwargs}
|
86 |
+
|
87 |
+
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Parameters that identify this model configuration."""
        return self._default_params
|
90 |
+
|
91 |
+
@property
|
92 |
+
def _invocation_params(self) -> Dict[str, Any]:
|
93 |
+
"""Get the parameters used to invoke the model."""
|
94 |
+
auth_cfg: Dict[str, Optional[str]] = {
|
95 |
+
"api_type": "aistudio",
|
96 |
+
"access_token": self.aistudio_access_token,
|
97 |
+
}
|
98 |
+
return {**{"_config_": auth_cfg}, **self._default_params}
|
99 |
+
|
100 |
+
    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "erniebot"
|
104 |
+
|
105 |
+
@root_validator()
|
106 |
+
def validate_enviroment(cls, values: Dict) -> Dict:
|
107 |
+
values["aistudio_access_token"] = get_from_dict_or_env(
|
108 |
+
values,
|
109 |
+
"aistudio_access_token",
|
110 |
+
"EB_ACCESS_TOKEN",
|
111 |
+
)
|
112 |
+
|
113 |
+
try:
|
114 |
+
import erniebot
|
115 |
+
|
116 |
+
values["client"] = erniebot.ChatCompletion
|
117 |
+
except ImportError:
|
118 |
+
raise ImportError(
|
119 |
+
"Could not import erniebot python package. Please install it with `pip install erniebot`."
|
120 |
+
)
|
121 |
+
return values
|
122 |
+
|
123 |
+
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result for the given messages.

        When ``self.streaming`` is set, delegates to ``_stream`` and merges
        the chunks into a single generation; otherwise performs one blocking
        API call with retry.
        """
        if self.streaming:
            chunks = self._stream(messages, stop=stop, run_manager=run_manager, **kwargs)
            generation: Optional[ChatGenerationChunk] = None
            for chunk in chunks:
                if generation is None:
                    generation = chunk
                else:
                    # ChatGenerationChunk supports concatenation via `+`.
                    generation += chunk
            assert generation is not None
            return ChatResult(generations=[generation])
        else:
            params = self._invocation_params
            params.update(kwargs)
            params["messages"] = self._convert_messages_to_dicts(messages)
            # The system prompt is passed separately via the `system` param,
            # not as a message in the list.
            system_prompt = self._build_system_prompt_from_messages(messages)
            if system_prompt is not None:
                params["system"] = system_prompt
            params["stream"] = False
            response = _create_completion_with_retry(self, run_manager=run_manager, **params)
            return self._build_chat_result_from_response(response)
|
150 |
+
|
151 |
+
    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous counterpart of ``_generate``.

        When ``self.streaming`` is set, delegates to ``_astream`` and merges
        the chunks into a single generation; otherwise performs one awaited
        API call with retry.
        """
        if self.streaming:
            chunks = self._astream(messages, stop=stop, run_manager=run_manager, **kwargs)
            generation: Optional[ChatGenerationChunk] = None
            async for chunk in chunks:
                if generation is None:
                    generation = chunk
                else:
                    # ChatGenerationChunk supports concatenation via `+`.
                    generation += chunk
            assert generation is not None
            return ChatResult(generations=[generation])
        else:
            params = self._invocation_params
            params.update(kwargs)
            params["messages"] = self._convert_messages_to_dicts(messages)
            # The system prompt is passed separately via the `system` param.
            system_prompt = self._build_system_prompt_from_messages(messages)
            if system_prompt is not None:
                params["system"] = system_prompt
            params["stream"] = False
            response = await _acreate_completion_with_retry(self, run_manager=run_manager, **params)
            return self._build_chat_result_from_response(response)
|
178 |
+
|
179 |
+
def _stream(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
    """Yield completion chunks for *messages*.

    Raises TypeError when `stop` is supplied, since the backend does not
    support stop sequences while streaming.
    """
    if stop is not None:
        raise TypeError("Currently, `stop` is not supported when streaming is enabled.")
    request = self._invocation_params
    request.update(kwargs)
    request["messages"] = self._convert_messages_to_dicts(messages)
    sys_prompt = self._build_system_prompt_from_messages(messages)
    if sys_prompt is not None:
        request["system"] = sys_prompt
    request["stream"] = True
    for response in _create_completion_with_retry(self, run_manager=run_manager, **request):
        piece = self._build_chunk_from_response(response)
        yield piece
        if run_manager:
            run_manager.on_llm_new_token(piece.text, chunk=piece)
|
200 |
+
|
201 |
+
async def _astream(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
    """Asynchronously yield completion chunks for *messages*.

    Raises TypeError when `stop` is supplied, since the backend does not
    support stop sequences while streaming.
    """
    if stop is not None:
        raise TypeError("Currently, `stop` is not supported when streaming is enabled.")
    request = self._invocation_params
    request.update(kwargs)
    request["messages"] = self._convert_messages_to_dicts(messages)
    sys_prompt = self._build_system_prompt_from_messages(messages)
    if sys_prompt is not None:
        request["system"] = sys_prompt
    request["stream"] = True
    async for response in await _acreate_completion_with_retry(self, run_manager=run_manager, **request):
        piece = self._build_chunk_from_response(response)
        yield piece
        if run_manager:
            await run_manager.on_llm_new_token(piece.text, chunk=piece)
|
222 |
+
|
223 |
+
def _build_chat_result_from_response(self, response: Mapping[str, Any]) -> ChatResult:
    """Wrap a non-streaming API response into a LangChain ChatResult."""
    msg = self._convert_dict_to_message(self._build_dict_from_response(response))
    # NOTE(review): the finish reason is hard-coded to "stop"; the response
    # payload is not inspected for the actual reason — confirm against the API.
    generation = ChatGeneration(
        message=msg,
        generation_info=dict(finish_reason="stop"),
    )
    llm_output = {"token_usage": response.get("usage", {}), "model_name": self.model}
    return ChatResult(generations=[generation], llm_output=llm_output)
|
232 |
+
|
233 |
+
def _build_chunk_from_response(self, response: Mapping[str, Any]) -> ChatGenerationChunk:
    """Convert one streaming response payload into a generation chunk."""
    msg = self._convert_dict_to_message(self._build_dict_from_response(response))
    return ChatGenerationChunk(
        message=AIMessageChunk(
            content=msg.content,
            additional_kwargs=msg.additional_kwargs,
        )
    )
|
241 |
+
|
242 |
+
def _build_dict_from_response(self, response: Mapping[str, Any]) -> _MessageDict:
|
243 |
+
message_dict: _MessageDict = {"role": "assistant"}
|
244 |
+
if "function_call" in response:
|
245 |
+
message_dict["content"] = None
|
246 |
+
message_dict["function_call"] = response["function_call"]
|
247 |
+
else:
|
248 |
+
message_dict["content"] = response["result"]
|
249 |
+
return message_dict
|
250 |
+
|
251 |
+
def _build_system_prompt_from_messages(self, messages: List[BaseMessage]) -> Optional[str]:
|
252 |
+
system_message_content_list: List[str] = []
|
253 |
+
for msg in messages:
|
254 |
+
if isinstance(msg, SystemMessage):
|
255 |
+
if isinstance(msg.content, str):
|
256 |
+
system_message_content_list.append(msg.content)
|
257 |
+
else:
|
258 |
+
raise TypeError
|
259 |
+
if len(system_message_content_list) > 0:
|
260 |
+
return "\n".join(system_message_content_list)
|
261 |
+
else:
|
262 |
+
return None
|
263 |
+
|
264 |
+
def _convert_messages_to_dicts(self, messages: List[BaseMessage]) -> List[dict]:
|
265 |
+
erniebot_messages = []
|
266 |
+
for msg in messages:
|
267 |
+
if isinstance(msg, SystemMessage):
|
268 |
+
# Ignore system messages, as we handle them elsewhere.
|
269 |
+
continue
|
270 |
+
eb_msg = self._convert_message_to_dict(msg)
|
271 |
+
erniebot_messages.append(eb_msg)
|
272 |
+
return erniebot_messages
|
273 |
+
|
274 |
+
@staticmethod
def _convert_dict_to_message(message_dict: _MessageDict) -> BaseMessage:
    """Build the matching LangChain message object for an API message dict."""
    role = message_dict["role"]
    content = message_dict["content"]
    if role == "user":
        return HumanMessage(content=content)
    if role == "assistant":
        function_call = message_dict.get("function_call")
        extra = {"function_call": dict(function_call)} if function_call else {}
        return AIMessage(content=content or "", additional_kwargs=extra)
    if role == "function":
        return FunctionMessage(content=content, name=message_dict["name"])
    # Fall back to a generic chat message for any other role.
    return ChatMessage(content=content, role=role)
|
290 |
+
|
291 |
+
@staticmethod
def _convert_message_to_dict(message: BaseMessage) -> _MessageDict:
    """Translate a LangChain message into the ERNIE Bot API message format.

    Raises:
        TypeError: For message types this backend does not understand.
    """
    result: _MessageDict
    if isinstance(message, ChatMessage):
        result = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        result = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        result = {"role": "assistant", "content": message.content}
        if "function_call" in message.additional_kwargs:
            result["function_call"] = message.additional_kwargs["function_call"]
            # The API expects null content when a function call is present.
            if result["content"] == "":
                result["content"] = None
    elif isinstance(message, FunctionMessage):
        result = {
            "role": "function",
            "content": message.content,
            "name": message.name,
        }
    else:
        raise TypeError(f"Got unknown type {message}")

    return result
|
314 |
+
|
315 |
+
|
316 |
+
def _create_completion_with_retry(
    llm: ErnieBotChat,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Call the client's `create` endpoint, retrying on transient errors."""
    wrap_with_retries = _create_retry_decorator(llm, run_manager=run_manager)

    @wrap_with_retries
    def _do_create(**call_kwargs: Any) -> Any:
        return llm.client.create(**call_kwargs)

    return _do_create(**kwargs)
|
328 |
+
|
329 |
+
|
330 |
+
async def _acreate_completion_with_retry(
    llm: ErnieBotChat,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Call the client's async `acreate` endpoint, retrying on transient errors."""
    wrap_with_retries = _create_retry_decorator(llm, run_manager=run_manager)

    @wrap_with_retries
    async def _do_acreate(**call_kwargs: Any) -> Any:
        return await llm.client.acreate(**call_kwargs)

    return await _do_acreate(**kwargs)
|
342 |
+
|
343 |
+
|
344 |
+
def _create_retry_decorator(
    llm: ErnieBotChat,
    run_manager: Optional[Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]] = None,
) -> Callable[[Any], Any]:
    """Build a retry decorator for transient ERNIE Bot API failures."""
    import erniebot

    # Only timeouts and rate limits are worth retrying; other errors are fatal.
    retryable: List[Type[BaseException]] = [
        erniebot.errors.TimeoutError,
        erniebot.errors.RequestLimitError,
    ]
    return create_base_retry_decorator(
        error_types=retryable, max_retries=llm.max_retries, run_manager=run_manager
    )
|
erniebot-agent/erniebot_agent/extensions/langchain/embeddings/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .ernie import ErnieEmbeddings
|
erniebot-agent/erniebot_agent/extensions/langchain/embeddings/ernie.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import Any, Dict, List, Optional
|
4 |
+
|
5 |
+
from langchain.pydantic_v1 import BaseModel, root_validator
|
6 |
+
from langchain.schema.embeddings import Embeddings
|
7 |
+
from langchain.utils import get_from_dict_or_env
|
8 |
+
|
9 |
+
|
10 |
+
class ErnieEmbeddings(BaseModel, Embeddings):
    """ERNIE embedding models.

    To use, you should have the ``erniebot`` python package installed, and the
    environment variable ``EB_ACCESS_TOKEN`` set with your AI Studio access token.

    Example:
        .. code-block:: python
            from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
            ernie_embeddings = ErnieEmbeddings()
    """

    client: Any = None
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    chunk_size: int = 16
    """Chunk size to use when the input is a list of texts."""
    aistudio_access_token: Optional[str] = None
    """AI Studio access token."""
    model: str = "ernie-text-embedding"
    """Model to use."""
    request_timeout: Optional[int] = 60
    """How many seconds to wait for the server to send data before giving up."""

    ernie_client_id: Optional[str] = None
    ernie_client_secret: Optional[str] = None
    """For raising deprecation warnings."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the access token and bind the erniebot embedding client."""
        values["aistudio_access_token"] = get_from_dict_or_env(
            values,
            "aistudio_access_token",
            "EB_ACCESS_TOKEN",
        )

        try:
            import erniebot

            values["client"] = erniebot.Embedding
        except ImportError:
            raise ImportError(
                "Could not import erniebot python package. Please install it with `pip install erniebot`."
            )
        return values

    def embed_query(self, text: str) -> List[float]:
        """Embed a single piece of text."""
        resp = self.embed_documents([text])
        return resp[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed a single piece of text."""
        embeddings = await self.aembed_documents([text])
        return embeddings[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts, batching API requests by ``chunk_size``."""
        embeddings: List[List[float]] = []
        for batch in self._iter_batches(texts):
            resp = self.client.create(_config_=self._get_auth_config(), input=batch, model=self.model)
            embeddings.extend(item["embedding"] for item in resp["data"])
        return embeddings

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed a list of texts, batching API requests by ``chunk_size``."""
        embeddings: List[List[float]] = []
        for batch in self._iter_batches(texts):
            resp = await self.client.acreate(_config_=self._get_auth_config(), input=batch, model=self.model)
            # Same accumulation as the sync path (the original built and
            # discarded a one-element list per item here).
            embeddings.extend(item["embedding"] for item in resp["data"])
        return embeddings

    def _iter_batches(self, texts: List[str]) -> List[List[str]]:
        # Split inputs into consecutive batches of at most `chunk_size` texts;
        # shared by the sync and async embedding paths.
        return [texts[i : i + self.chunk_size] for i in range(0, len(texts), self.chunk_size)]

    def _get_auth_config(self) -> dict:
        # Authentication parameters forwarded to the erniebot client.
        return {"api_type": "aistudio", "access_token": self.aistudio_access_token}
|
erniebot-agent/erniebot_agent/extensions/langchain/llms/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .erniebot import ErnieBot
|
erniebot-agent/erniebot_agent/extensions/langchain/llms/erniebot.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import (
|
4 |
+
Any,
|
5 |
+
AsyncIterator,
|
6 |
+
Callable,
|
7 |
+
Dict,
|
8 |
+
Iterator,
|
9 |
+
List,
|
10 |
+
Mapping,
|
11 |
+
Optional,
|
12 |
+
Type,
|
13 |
+
Union,
|
14 |
+
)
|
15 |
+
|
16 |
+
from langchain.callbacks.manager import (
|
17 |
+
AsyncCallbackManagerForLLMRun,
|
18 |
+
CallbackManagerForLLMRun,
|
19 |
+
)
|
20 |
+
from langchain.llms.base import LLM, create_base_retry_decorator
|
21 |
+
from langchain.llms.utils import enforce_stop_tokens
|
22 |
+
from langchain.pydantic_v1 import Field, root_validator
|
23 |
+
from langchain.schema.output import GenerationChunk
|
24 |
+
from langchain.utils import get_from_dict_or_env
|
25 |
+
|
26 |
+
|
27 |
+
class ErnieBot(LLM):
    """ERNIE Bot large language models.

    To use, you should have the ``erniebot`` python package installed, and the
    environment variable ``EB_ACCESS_TOKEN`` set with your AI Studio access token.

    Example:
        .. code-block:: python

            from erniebot_agent.extensions.langchain.llms import ErnieBot
            erniebot = ErnieBot(model="ernie-bot")
    """

    client: Any = None
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    aistudio_access_token: Optional[str] = None
    """AI Studio access token."""
    streaming: Optional[bool] = False
    """Whether to stream the results or not."""
    model: str = "ernie-bot"
    """Model to use."""
    top_p: Optional[float] = 0.8
    """Parameter of nucleus sampling that affects the diversity of generated content."""
    temperature: Optional[float] = 0.95
    """Sampling temperature to use."""
    penalty_score: Optional[float] = 1
    """Penalty assigned to tokens that have been generated."""
    request_timeout: Optional[int] = 60
    """How many seconds to wait for the server to send data before giving up."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling ERNIE Bot API."""
        normal_params = {
            "model": self.model,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "penalty_score": self.penalty_score,
            "request_timeout": self.request_timeout,
        }
        # Values in `model_kwargs` override the explicit fields above.
        return {**normal_params, **self.model_kwargs}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Parameters identifying this LLM configuration."""
        return self._default_params

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        auth_cfg: Dict[str, Optional[str]] = {
            "api_type": "aistudio",
            "access_token": self.aistudio_access_token,
        }
        return {**{"_config_": auth_cfg}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "erniebot"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the access token and bind the erniebot chat client."""
        # Renamed from "validate_enviroment" (typo) for consistency with
        # ErnieEmbeddings.validate_environment; pydantic registers validators
        # via the decorator, so the rename does not change behavior.
        values["aistudio_access_token"] = get_from_dict_or_env(
            values,
            "aistudio_access_token",
            "EB_ACCESS_TOKEN",
        )

        try:
            import erniebot

            values["client"] = erniebot.ChatCompletion
        except ImportError:
            raise ImportError(
                "Could not import erniebot python package. Please install it with `pip install erniebot`."
            )
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for *prompt*, streaming if configured."""
        if self.streaming:
            text = ""
            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
                text += chunk.text
            return text
        else:
            params = self._invocation_params
            params.update(kwargs)
            params["messages"] = [self._build_user_message_from_prompt(prompt)]
            params["stream"] = False
            response = _create_completion_with_retry(self, run_manager=run_manager, **params)
            text = response["result"]
            if stop is not None:
                text = enforce_stop_tokens(text, stop)
            return text

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Asynchronously generate a completion for *prompt*."""
        if self.streaming:
            text = ""
            async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
                text += chunk.text
            return text
        else:
            params = self._invocation_params
            params.update(kwargs)
            params["messages"] = [self._build_user_message_from_prompt(prompt)]
            params["stream"] = False
            response = await _acreate_completion_with_retry(self, run_manager=run_manager, **params)
            text = response["result"]
            if stop is not None:
                text = enforce_stop_tokens(text, stop)
            return text

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Yield generation chunks for *prompt*.

        Raises:
            TypeError: If `stop` is supplied; stop sequences are not
                supported while streaming.
        """
        if stop is not None:
            raise TypeError("Currently, `stop` is not supported when streaming is enabled.")
        params = self._invocation_params
        params.update(kwargs)
        params["messages"] = [self._build_user_message_from_prompt(prompt)]
        params["stream"] = True
        for resp in _create_completion_with_retry(self, run_manager=run_manager, **params):
            chunk = self._build_chunk_from_response(resp)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        """Asynchronously yield generation chunks for *prompt*.

        Raises:
            TypeError: If `stop` is supplied; stop sequences are not
                supported while streaming.
        """
        if stop is not None:
            raise TypeError("Currently, `stop` is not supported when streaming is enabled.")
        params = self._invocation_params
        params.update(kwargs)
        params["messages"] = [self._build_user_message_from_prompt(prompt)]
        params["stream"] = True
        async for resp in await _acreate_completion_with_retry(self, run_manager=run_manager, **params):
            chunk = self._build_chunk_from_response(resp)
            yield chunk
            if run_manager:
                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)

    def _build_chunk_from_response(self, response: Mapping[str, Any]) -> GenerationChunk:
        """Convert one streaming response payload into a generation chunk."""
        return GenerationChunk(text=response["result"])

    def _build_user_message_from_prompt(self, prompt: str) -> Dict[str, str]:
        """Wrap a raw prompt string as a user message dict for the API."""
        return {"role": "user", "content": prompt}
|
197 |
+
|
198 |
+
|
199 |
+
def _create_completion_with_retry(
    llm: ErnieBot,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Call the client's `create` endpoint, retrying on transient errors."""
    wrap_with_retries = _create_retry_decorator(llm, run_manager=run_manager)

    @wrap_with_retries
    def _do_create(**call_kwargs: Any) -> Any:
        return llm.client.create(**call_kwargs)

    return _do_create(**kwargs)
|
211 |
+
|
212 |
+
|
213 |
+
async def _acreate_completion_with_retry(
    llm: ErnieBot,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Call the client's async `acreate` endpoint, retrying on transient errors."""
    wrap_with_retries = _create_retry_decorator(llm, run_manager=run_manager)

    @wrap_with_retries
    async def _do_acreate(**call_kwargs: Any) -> Any:
        return await llm.client.acreate(**call_kwargs)

    return await _do_acreate(**kwargs)
|
225 |
+
|
226 |
+
|
227 |
+
def _create_retry_decorator(
    llm: ErnieBot,
    run_manager: Optional[Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]] = None,
) -> Callable[[Any], Any]:
    """Build a retry decorator for transient ERNIE Bot API failures."""
    import erniebot

    # Only timeouts and rate limits are worth retrying; other errors are fatal.
    retryable: List[Type[BaseException]] = [
        erniebot.errors.TimeoutError,
        erniebot.errors.RequestLimitError,
    ]
    return create_base_retry_decorator(
        error_types=retryable, max_retries=llm.max_retries, run_manager=run_manager
    )
|
erniebot-agent/erniebot_agent/file_io/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
erniebot-agent/erniebot_agent/file_io/base.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import abc
|
16 |
+
|
17 |
+
|
18 |
+
class File(metaclass=abc.ABCMeta):
    """Abstract base class for files managed by the agent.

    A file is identified by its ``id``: two ``File`` objects compare equal
    iff their IDs are equal, regardless of filename or creation time.
    """

    def __init__(self, id: str, filename: str, created_at: int) -> None:
        super().__init__()
        self.id = id
        self.filename = filename
        self.created_at = created_at

    def __eq__(self, other: object) -> bool:
        if isinstance(other, File):
            return self.id == other.id
        else:
            return False

    def __hash__(self) -> int:
        # Defining __eq__ alone implicitly sets __hash__ to None, which made
        # File instances unhashable. Hash by the same key used for equality
        # so files can be stored in sets and used as dict keys.
        return hash(self.id)

    def __repr__(self) -> str:
        attrs_str = self._get_attrs_str()
        return f"<{self.__class__.__name__} {attrs_str}>"

    @abc.abstractmethod
    async def read_contents(self) -> bytes:
        """Return the full contents of the file as bytes."""
        raise NotImplementedError

    def _get_attrs_str(self) -> str:
        # Human-readable attribute summary used by __repr__.
        return ", ".join(
            [
                f"id: {repr(self.id)}",
                f"filename: {repr(self.filename)}",
                f"created_at: {repr(self.created_at)}",
            ]
        )
|
erniebot-agent/erniebot_agent/file_io/file_manager.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import pathlib
|
17 |
+
import uuid
|
18 |
+
from typing import Literal, Optional, Union, overload
|
19 |
+
|
20 |
+
import anyio
|
21 |
+
from erniebot_agent.file_io.base import File
|
22 |
+
from erniebot_agent.file_io.file_registry import FileRegistry, get_file_registry
|
23 |
+
from erniebot_agent.file_io.local_file import LocalFile, create_local_file_from_path
|
24 |
+
from erniebot_agent.file_io.remote_file import RemoteFile, RemoteFileClient
|
25 |
+
from erniebot_agent.utils.temp_file import create_tracked_temp_dir
|
26 |
+
from typing_extensions import TypeAlias
|
27 |
+
|
28 |
+
_PathType: TypeAlias = Union[str, os.PathLike]
|
29 |
+
|
30 |
+
|
31 |
+
class FileManager(object):
    """Creates, uploads, and looks up local and remote files.

    Local files are materialized under ``save_dir`` (or a tracked temporary
    directory when no ``save_dir`` is given); remote files go through the
    configured ``RemoteFileClient``. Created files are recorded in the
    process-wide file registry.
    """

    _remote_file_client: Optional[RemoteFileClient]

    def __init__(
        self,
        remote_file_client: Optional[RemoteFileClient] = None,
        *,
        auto_register: bool = True,
        save_dir: Optional[_PathType] = None,
    ) -> None:
        super().__init__()
        if remote_file_client is not None:
            self._remote_file_client = remote_file_client
        else:
            self._remote_file_client = None
        self._auto_register = auto_register
        if save_dir is not None:
            self._save_dir = pathlib.Path(save_dir)
        else:
            # This can be done lazily, but we need to be careful about race conditions.
            self._save_dir = create_tracked_temp_dir()

        self._file_registry = get_file_registry()

    @property
    def registry(self) -> FileRegistry:
        """The global file registry this manager records files in."""
        return self._file_registry

    @property
    def remote_file_client(self) -> RemoteFileClient:
        """The configured remote file client.

        Raises:
            AttributeError: If no remote file client was provided.
        """
        if self._remote_file_client is None:
            raise AttributeError("No remote file client is set.")
        else:
            return self._remote_file_client

    @overload
    async def create_file_from_path(
        self, file_path: _PathType, *, file_type: Literal["local"] = ...
    ) -> LocalFile:
        ...

    @overload
    async def create_file_from_path(
        self, file_path: _PathType, *, file_type: Literal["remote"]
    ) -> RemoteFile:
        ...

    async def create_file_from_path(
        self, file_path: _PathType, *, file_type: Literal["local", "remote"] = "local"
    ) -> Union[LocalFile, RemoteFile]:
        """Create a local or remote file from a path on disk.

        Raises:
            ValueError: If ``file_type`` is neither "local" nor "remote".
        """
        file: Union[LocalFile, RemoteFile]
        if file_type == "local":
            file = await self.create_local_file_from_path(file_path)
        elif file_type == "remote":
            file = await self.create_remote_file_from_path(file_path)
        else:
            raise ValueError(f"Unsupported file type: {file_type}")
        return file

    async def create_local_file_from_path(self, file_path: _PathType) -> LocalFile:
        """Wrap an existing on-disk path as a LocalFile and register it."""
        # NOTE(review): registration here is unconditional, while the remote
        # path below honors `auto_register` — confirm whether this asymmetry
        # is intentional.
        file = create_local_file_from_path(pathlib.Path(file_path))
        self._file_registry.register_file(file)
        return file

    async def create_remote_file_from_path(self, file_path: _PathType) -> RemoteFile:
        """Upload an on-disk file via the remote client; register it if enabled."""
        file = await self.remote_file_client.upload_file(pathlib.Path(file_path))
        if self._auto_register:
            self._file_registry.register_file(file)
        return file

    @overload
    async def create_file_from_bytes(
        self, file_contents: bytes, filename: str, *, file_type: Literal["local"] = ...
    ) -> LocalFile:
        ...

    @overload
    async def create_file_from_bytes(
        self, file_contents: bytes, filename: str, *, file_type: Literal["remote"]
    ) -> RemoteFile:
        ...

    async def create_file_from_bytes(
        self, file_contents: bytes, filename: str, *, file_type: Literal["local", "remote"] = "local"
    ) -> Union[LocalFile, RemoteFile]:
        """Persist raw bytes to a uniquely named file, then create a File from it.

        The original filename's stem and suffix are kept; a UUID is inserted
        to avoid collisions in the save directory.
        """
        # Can we do this with in-memory files?
        file_path = self._fs_create_file(
            prefix=pathlib.PurePath(filename).stem, suffix=pathlib.PurePath(filename).suffix
        )
        async with await anyio.open_file(file_path, "wb") as f:
            await f.write(file_contents)
        file = await self.create_file_from_path(file_path, file_type=file_type)
        return file

    async def retrieve_remote_file_by_id(self, file_id: str) -> RemoteFile:
        """Fetch a remote file's metadata by ID; register it if enabled."""
        file = await self.remote_file_client.retrieve_file(file_id)
        if self._auto_register:
            self._file_registry.register_file(file)
        return file

    def look_up_file_by_id(self, file_id: str) -> Optional[File]:
        """Return the registered file with the given ID, or None if unknown."""
        return self._file_registry.look_up_file(file_id)

    def _fs_create_file(self, prefix: Optional[str] = None, suffix: Optional[str] = None) -> pathlib.Path:
        # Create (touch) an empty, uniquely named file in the save directory
        # and return its path.
        filename = f"{prefix or ''}{str(uuid.uuid4())}{suffix or ''}"
        file_path = self._save_dir / filename
        file_path.touch()
        return file_path
|
erniebot-agent/erniebot_agent/file_io/file_registry.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import threading
|
16 |
+
from typing import Dict, List, Optional
|
17 |
+
|
18 |
+
from erniebot_agent.file_io.base import File
|
19 |
+
from erniebot_agent.utils.misc import Singleton
|
20 |
+
|
21 |
+
|
22 |
+
class FileRegistry(metaclass=Singleton):
    """Process-wide, thread-safe mapping from file IDs to `File` objects."""

    def __init__(self) -> None:
        super().__init__()
        self._id_to_file: Dict[str, File] = {}
        self._lock = threading.Lock()

    def register_file(self, file: File) -> None:
        """Register `file` under its ID; re-registering simply overwrites the entry."""
        with self._lock:
            self._id_to_file[file.id] = file

    def unregister_file(self, file: File) -> None:
        """Remove `file` from the registry; raise if its ID was never registered."""
        with self._lock:
            if file.id not in self._id_to_file:
                raise RuntimeError(f"ID {repr(file.id)} is not registered.")
            del self._id_to_file[file.id]

    def look_up_file(self, file_id: str) -> Optional[File]:
        """Return the file registered under `file_id`, or None."""
        with self._lock:
            return self._id_to_file.get(file_id, None)

    def list_files(self) -> List[File]:
        """Return a snapshot list of all registered files."""
        with self._lock:
            return list(self._id_to_file.values())


_file_registry = FileRegistry()


def get_file_registry() -> FileRegistry:
    """Return the process-wide `FileRegistry` singleton."""
    return _file_registry
|
erniebot-agent/erniebot_agent/file_io/local_file.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import pathlib
|
16 |
+
import time
|
17 |
+
import uuid
|
18 |
+
|
19 |
+
import anyio
|
20 |
+
from erniebot_agent.file_io.base import File
|
21 |
+
from erniebot_agent.file_io.protocol import (
|
22 |
+
build_local_file_id_from_uuid,
|
23 |
+
is_local_file_id,
|
24 |
+
)
|
25 |
+
from erniebot_agent.utils.logging import logger
|
26 |
+
|
27 |
+
|
28 |
+
class LocalFile(File):
    """A file stored on the local filesystem."""

    def __init__(self, id: str, filename: str, created_at: int, path: pathlib.Path) -> None:
        if not is_local_file_id(id):
            # BUG FIX: the original string lacked the `f` prefix, so the error
            # printed the literal text "{id}" instead of the offending ID.
            raise ValueError(f"Invalid file ID: {id}")
        super().__init__(id=id, filename=filename, created_at=created_at)
        self.path = path  # Location of the file on disk.

    async def read_contents(self) -> bytes:
        """Asynchronously read and return the raw bytes of the file."""
        return await anyio.Path(self.path).read_bytes()

    def _get_attrs_str(self) -> str:
        attrs_str = super()._get_attrs_str()
        attrs_str += f", path: {repr(self.path)}"
        return attrs_str


def create_local_file_from_path(file_path: pathlib.Path) -> LocalFile:
    """Build a `LocalFile` describing `file_path`, warning if the path does not exist."""
    if not file_path.exists():
        # FIX: `Logger.warn` is a deprecated alias; use `warning`.
        logger.warning("File %s does not exist.", file_path)
    file_id = _generate_local_file_id()
    filename = file_path.name
    created_at = int(time.time())
    file = LocalFile(id=file_id, filename=filename, created_at=created_at, path=file_path)
    return file


def _generate_local_file_id() -> str:
    """Generate a fresh local-file ID from a UUID."""
    return build_local_file_id_from_uuid(str(uuid.uuid1()))
|
erniebot-agent/erniebot_agent/file_io/protocol.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import re
|
16 |
+
from typing import List
|
17 |
+
|
18 |
+
_LOCAL_FILE_ID_PREFIX = "file-local-"
_REMOTE_FILE_ID_PREFIX = "file-remote-"
# Canonical lowercase-hex UUID: 8-4-4-4-12 hex digits.
_UUID_PATTERN = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
_LOCAL_FILE_ID_PATTERN = _LOCAL_FILE_ID_PREFIX + _UUID_PATTERN
_REMOTE_FILE_ID_PATTERN = _REMOTE_FILE_ID_PREFIX + _UUID_PATTERN

# Compile once at import time; these are used for both validation and extraction.
_compiled_local_file_id_pattern = re.compile(_LOCAL_FILE_ID_PATTERN)
_compiled_remote_file_id_pattern = re.compile(_REMOTE_FILE_ID_PATTERN)


def build_local_file_id_from_uuid(uuid: str) -> str:
    """Return the local-file ID for the given UUID string."""
    return _LOCAL_FILE_ID_PREFIX + uuid


def build_remote_file_id_from_uuid(uuid: str) -> str:
    """Return the remote-file ID for the given UUID string."""
    return _REMOTE_FILE_ID_PREFIX + uuid


def is_file_id(str_: str) -> bool:
    """True if `str_` is a well-formed local or remote file ID."""
    return is_local_file_id(str_) or is_remote_file_id(str_)


def is_local_file_id(str_: str) -> bool:
    """True if the whole of `str_` is a well-formed local file ID."""
    return bool(_compiled_local_file_id_pattern.fullmatch(str_))


def is_remote_file_id(str_: str) -> bool:
    """True if the whole of `str_` is a well-formed remote file ID."""
    return bool(_compiled_remote_file_id_pattern.fullmatch(str_))


def extract_file_ids(str_: str) -> List[str]:
    """Return every file ID embedded in `str_` (local IDs first, then remote)."""
    return extract_local_file_ids(str_) + extract_remote_file_ids(str_)


def extract_local_file_ids(str_: str) -> List[str]:
    """Return every local file ID embedded in `str_`, in order of appearance."""
    return _compiled_local_file_id_pattern.findall(str_)


def extract_remote_file_ids(str_: str) -> List[str]:
    """Return every remote file ID embedded in `str_`, in order of appearance."""
    return _compiled_remote_file_id_pattern.findall(str_)
|
erniebot-agent/erniebot_agent/file_io/remote_file.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import abc
|
16 |
+
import asyncio
|
17 |
+
import functools
|
18 |
+
import pathlib
|
19 |
+
import time
|
20 |
+
import uuid
|
21 |
+
from typing import ClassVar, Dict, List
|
22 |
+
|
23 |
+
import anyio
|
24 |
+
from baidubce.auth.bce_credentials import BceCredentials
|
25 |
+
from baidubce.bce_client_configuration import BceClientConfiguration
|
26 |
+
from baidubce.services.bos.bos_client import BosClient
|
27 |
+
from erniebot_agent.file_io.base import File
|
28 |
+
from erniebot_agent.file_io.protocol import (
|
29 |
+
build_remote_file_id_from_uuid,
|
30 |
+
is_remote_file_id,
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
class RemoteFile(File):
    """A file hosted by a remote file service, accessed through a `RemoteFileClient`."""

    def __init__(self, id: str, filename: str, created_at: int, client: "RemoteFileClient") -> None:
        if not is_remote_file_id(id):
            # BUG FIX: the original string lacked the `f` prefix, so the error
            # printed the literal text "{id}" instead of the offending ID.
            raise ValueError(f"Invalid file ID: {id}")
        super().__init__(id=id, filename=filename, created_at=created_at)
        self._client = client

    async def read_contents(self) -> bytes:
        """Download and return the raw bytes of the file."""
        file_contents = await self._client.retrieve_file_contents(self.id)
        return file_contents

    async def delete(self) -> None:
        """Delete the file on the remote service."""
        await self._client.delete_file(self.id)
47 |
+
|
48 |
+
|
49 |
+
class RemoteFileClient(metaclass=abc.ABCMeta):
    """Abstract interface for a remote file-storage backend."""

    @abc.abstractmethod
    async def upload_file(self, file_path: pathlib.Path) -> RemoteFile:
        """Upload the file at `file_path` and return its remote handle."""
        raise NotImplementedError

    @abc.abstractmethod
    async def retrieve_file(self, file_id: str) -> RemoteFile:
        """Return the remote handle for an existing file."""
        raise NotImplementedError

    @abc.abstractmethod
    async def retrieve_file_contents(self, file_id: str) -> bytes:
        """Download and return the raw bytes of a file."""
        raise NotImplementedError

    @abc.abstractmethod
    async def list_files(self) -> List[RemoteFile]:
        """Enumerate the files held by the backend."""
        raise NotImplementedError

    @abc.abstractmethod
    async def delete_file(self, file_id: str) -> None:
        """Remove a file from the backend."""
        raise NotImplementedError
|
69 |
+
|
70 |
+
|
71 |
+
class BOSFileClient(RemoteFileClient):
    """`RemoteFileClient` backed by Baidu Object Storage (BOS).

    The BOS SDK is synchronous, so every SDK call is dispatched to the default
    executor to avoid blocking the event loop.
    """

    _ENDPOINT: ClassVar[str] = "bj.bcebos.com"

    def __init__(self, ak: str, sk: str, bucket_name: str, prefix: str) -> None:
        super().__init__()
        self.bucket_name = bucket_name
        self.prefix = prefix  # Key prefix under which objects are stored.
        config = BceClientConfiguration(credentials=BceCredentials(ak, sk), endpoint=self._ENDPOINT)
        self._bos_client = BosClient(config=config)

    async def upload_file(self, file_path: pathlib.Path) -> RemoteFile:
        """Upload `file_path` to BOS and return the corresponding `RemoteFile`."""
        file_id = self._generate_file_id()
        filename = file_path.name
        created_at = int(time.time())
        # File attributes travel with the object as BOS user metadata.
        user_metadata: Dict[str, str] = {"id": file_id, "filename": filename, "created_at": str(created_at)}
        async with await anyio.open_file(file_path, mode="rb") as f:
            data = await f.read()
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            functools.partial(
                self._bos_client.put_object_from_string,
                bucket=self.bucket_name,
                key=self._get_key(file_id),
                data=data,
                user_metadata=user_metadata,
            ),
        )
        return RemoteFile(id=file_id, filename=filename, created_at=created_at, client=self)

    async def retrieve_file(self, file_id: str) -> RemoteFile:
        """Rebuild a `RemoteFile` from the object's BOS user metadata."""
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            functools.partial(
                self._bos_client.get_object_meta_data, self.bucket_name, self._get_key(file_id)
            ),
        )
        user_metadata = {
            "id": response.metadata.bce_meta_id,
            "filename": response.metadata.bce_meta_filename,
            "created_at": int(response.metadata.bce_meta_created_at),
        }
        # Sanity check: the stored metadata must agree with the requested ID.
        if file_id != user_metadata["id"]:
            raise RuntimeError("`file_id` is not the same as the one in metadata.")

        return RemoteFile(
            id=user_metadata["id"],
            filename=user_metadata["filename"],
            created_at=user_metadata["created_at"],
            client=self,
        )

    async def retrieve_file_contents(self, file_id: str) -> bytes:
        """Download and return the object's bytes."""
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            functools.partial(
                self._bos_client.get_object_as_string, self.bucket_name, self._get_key(file_id)
            ),
        )
        return result

    async def list_files(self) -> List[RemoteFile]:
        # Listing is intentionally unsupported for this backend.
        raise RuntimeError(f"`{self.__class__.__name__}.list_files` is not supported.")

    async def delete_file(self, file_id: str) -> None:
        """Delete the object from the bucket."""
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None, functools.partial(self._bos_client.delete_object, self.bucket_name, self._get_key(file_id))
        )

    def _get_key(self, file_id: str) -> str:
        """Map a file ID to its object key inside the bucket."""
        return self.prefix + file_id

    @staticmethod
    def _generate_file_id() -> str:
        """Generate a fresh remote-file ID from a UUID."""
        return build_remote_file_id_from_uuid(str(uuid.uuid1()))
|
erniebot-agent/erniebot_agent/memory/__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .base import Memory
|
16 |
+
from .limit_token_memory import LimitTokensMemory
|
17 |
+
from .sliding_window_memory import SlidingWindowMemory
|
18 |
+
from .whole_memory import WholeMemory
|
erniebot-agent/erniebot_agent/memory/base.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import List, Optional, Union
|
16 |
+
|
17 |
+
from erniebot_agent.messages import AIMessage, Message, SystemMessage
|
18 |
+
|
19 |
+
|
20 |
+
class MessageManager:
    """Store the ordered list of chat messages plus a single system message."""

    def __init__(self) -> None:
        self.messages: List[Message] = []
        self._system_message: Union[SystemMessage, None] = None

    @property
    def system_message(self) -> Optional[Message]:
        """The single system message held by the manager, or None if unset."""
        return self._system_message

    @system_message.setter
    def system_message(self, message: SystemMessage) -> None:
        if self._system_message is not None:
            # BUG FIX: the original code built a `Warning` object and discarded
            # it, so the notice was never emitted. Actually warn the caller.
            import warnings

            warnings.warn("system message has been set, the previous one will be replaced")

        self._system_message = message

    def add_messages(self, messages: List[Message]) -> None:
        """Append several messages in order."""
        self.messages.extend(messages)

    def add_message(self, message: Message) -> None:
        """Append one message; a `SystemMessage` replaces the system slot instead."""
        if isinstance(message, SystemMessage):
            self.system_message = message
        else:
            self.messages.append(message)

    def pop_message(self) -> Message:
        """Remove and return the oldest message."""
        return self.messages.pop(0)

    def clear_messages(self) -> None:
        """Drop all stored messages (the system message is kept)."""
        self.messages = []

    def update_last_message_token_count(self, token_count: int) -> None:
        """Set the token count on the newest message.

        A count of 0 means the true count is unknown; fall back to the content
        length as a rough approximation.
        """
        if token_count == 0:
            self.messages[-1].token_count = len(self.messages[-1].content)
        else:
            self.messages[-1].token_count = token_count

    def retrieve_messages(self) -> List[Message]:
        """Return the (live) list of stored messages."""
        return self.messages
|
68 |
+
|
69 |
+
|
70 |
+
class Memory:
    """The base class of memory, backed by a `MessageManager`."""

    def __init__(self):
        self.msg_manager = MessageManager()

    def add_messages(self, messages: List[Message]):
        """Add several messages, one at a time, so subclass hooks run for each."""
        for message in messages:
            self.add_message(message)

    def add_message(self, message: Message):
        """Add one message to memory.

        An `AIMessage` carries the prompt token usage reported by the LLM,
        which belongs to the preceding (human) message, so it is back-filled
        before the AI reply is stored.
        """
        if isinstance(message, AIMessage):
            self.msg_manager.update_last_message_token_count(message.query_tokens_count)
        self.msg_manager.add_message(message)

    def get_messages(self) -> List[Message]:
        """Return the stored conversation messages."""
        return self.msg_manager.retrieve_messages()

    def get_system_message(self) -> Optional[SystemMessage]:
        # FIX: the system message may never have been set, so the return type
        # is Optional (the original annotation promised a SystemMessage).
        return self.msg_manager.system_message

    def clear_chat_history(self):
        """Drop all stored conversation messages."""
        self.msg_manager.clear_messages()
|
93 |
+
|
94 |
+
|
95 |
+
class WholeMemory(Memory):
    """A memory that keeps every message, with no pruning at all."""

    def __init__(self):
        super().__init__()
|
erniebot-agent/erniebot_agent/memory/limit_token_memory.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from erniebot_agent.memory import Memory
|
17 |
+
from erniebot_agent.messages import AIMessage, Message
|
18 |
+
|
19 |
+
|
20 |
+
class LimitTokensMemory(Memory):
    """Memory that keeps the total token count below `max_token_limit`.

    If tokens exceed the limit, the oldest messages are popped from memory.
    A limit of None disables pruning entirely.
    """

    def __init__(self, max_token_limit=6000):
        super().__init__()
        self.max_token_limit = max_token_limit
        self.mem_token_count = 0  # Running total of tokens held in memory.

        assert (
            max_token_limit is None
        ) or max_token_limit > 0, "max_token_limit should be None or positive integer, \
                but got {max_token_limit}".format(
            max_token_limit=max_token_limit
        )

    def add_message(self, message: Message):
        super().add_message(message)
        # TODO: pruning only when an AIMessage arrives means a long
        # HumanMessage may still exceed the limit when sent to the LLM.
        # Ideally each message would carry its token count at creation time so
        # pruning could happen on every insertion.
        if isinstance(message, AIMessage):
            self.prune_message()

    def prune_message(self):
        """Drop oldest messages until the running token total fits the limit.

        Raises:
            RuntimeError: if even an emptied memory cannot satisfy the limit
                (i.e. the newest exchange alone exceeds `max_token_limit`).
        """
        self.mem_token_count += self.msg_manager.messages[-1].token_count
        self.mem_token_count += self.msg_manager.messages[-2].token_count  # Add human message tokens.
        if self.max_token_limit is not None:
            while self.mem_token_count > self.max_token_limit:
                # ROBUSTNESS FIX: the original popped first and only detected
                # an emptied memory afterwards, which could raise IndexError
                # (pop from empty list) and referenced a possibly-unbound
                # variable in its error message. Check before popping.
                if len(self.msg_manager.messages) == 0:
                    raise RuntimeError(
                        "The message memory is empty, but the accumulated token "
                        "count still exceeds the limit of {} tokens.".format(self.max_token_limit)
                    )
                deleted_message = self.msg_manager.pop_message()
                self.mem_token_count -= deleted_message.token_count
|
erniebot-agent/erniebot_agent/memory/sliding_window_memory.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from erniebot_agent.memory import Memory
|
16 |
+
from erniebot_agent.messages import Message
|
17 |
+
|
18 |
+
|
19 |
+
class SlidingWindowMemory(Memory):
    """Memory that retains at most `max_num_message` recent messages."""

    def __init__(self, max_num_message: int):
        super().__init__()
        self.max_num_message = max_num_message

        assert (isinstance(max_num_message, int)) and (
            max_num_message > 0
        ), "max_num_message should be positive integer, but got {max_token_limit}".format(
            max_token_limit=max_num_message
        )

    def add_message(self, message: Message):
        super().add_message(message=message)
        self.prune_message()

    def prune_message(self):
        """Evict the oldest messages until the window size is respected."""
        while len(self.get_messages()) > self.max_num_message:
            self.msg_manager.pop_message()
            # `messages` must have an odd number of elements; drop one more
            # so the history keeps whole human/AI exchanges aligned.
            if len(self.get_messages()) % 2 == 0:
                self.msg_manager.pop_message()
|
erniebot-agent/erniebot_agent/memory/whole_memory.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from erniebot_agent.memory.base import Memory
|
16 |
+
|
17 |
+
|
18 |
+
class WholeMemory(Memory):
    """A memory that retains the entire message history without pruning.

    All behavior is inherited from `Memory`; this subclass exists as the
    explicit "no limit" choice alongside the token- and window-limited
    memories. (The original class had an empty docstring.)
    """
|
erniebot-agent/erniebot_agent/messages.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License"
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7 |
+
#
|
8 |
+
# Unless required by applicable law or agreed to in writing, software
|
9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11 |
+
# See the License for the specific language governing permissions and
|
12 |
+
# limitations under the License
|
13 |
+
|
14 |
+
from typing import Dict, List, Optional, TypedDict
|
15 |
+
|
16 |
+
import erniebot.utils.token_helper as token_helper
|
17 |
+
|
18 |
+
|
19 |
+
class Message:
    """The base class of a message."""

    def __init__(self, role: str, content: str, token_count: Optional[int] = None):
        self.role = role
        self.content = content
        self._token_count = token_count
        # Fields exported by `to_dict` and shown by `__str__`/`__repr__`.
        self._param_names = ["role", "content"]

    @property
    def token_count(self):
        """The number of tokens of the message; raises if it was never set."""
        if self._token_count is None:
            raise AttributeError("The token count of the message has not been set.")
        return self._token_count

    @token_count.setter
    def token_count(self, token_count: int):
        """Set the number of tokens; a second assignment is rejected."""
        if self._token_count is not None:
            raise AttributeError("The token count of the message can only be set once.")
        self._token_count = token_count

    def to_dict(self) -> Dict[str, str]:
        """Return the message as a plain dict of its exported fields."""
        return {name: getattr(self, name) for name in self._param_names}

    def __str__(self) -> str:
        return f"<{self._get_attrs_str()}>"

    def __repr__(self):
        return f"<{self.__class__.__name__} {self._get_attrs_str()}>"

    def _get_attrs_str(self) -> str:
        """Render the non-empty exported fields (plus token count) for display."""
        parts: List[str] = []
        for name in self._param_names:
            value = getattr(self, name)
            if value is not None and value != "":
                parts.append(f"{name}: {repr(value)}")
        if self._token_count is not None:
            parts.append(f"token_count: {self._token_count}")
        return ", ".join(parts)


class SystemMessage(Message):
    """A message from a human to set system information."""

    def __init__(self, content: str):
        # The token count of a system message is approximated by its length.
        super().__init__(role="system", content=content, token_count=len(content))


class HumanMessage(Message):
    """A message from a human."""

    def __init__(self, content: str):
        super().__init__(role="user", content=content)


class FunctionCall(TypedDict):
    """A function call requested by the assistant."""

    name: str
    thoughts: str
    arguments: str


class TokenUsage(TypedDict):
    """Token usage statistics reported by the LLM."""

    prompt_tokens: int
    completion_tokens: int


class AIMessage(Message):
    """A message from an assistant."""

    def __init__(
        self,
        content: str,
        function_call: Optional[FunctionCall],
        token_usage: Optional[TokenUsage] = None,
    ):
        if token_usage is None:
            # No usage information from the LLM: approximate the completion
            # size and leave the prompt size unknown (0).
            prompt_tokens = 0
            completion_tokens = token_helper.approx_num_tokens(content)
        else:
            prompt_tokens, completion_tokens = self._parse_token_count(token_usage)
        super().__init__(role="assistant", content=content, token_count=completion_tokens)
        self.function_call = function_call
        self.query_tokens_count = prompt_tokens
        self._param_names = ["role", "content", "function_call"]

    def _parse_token_count(self, token_usage: TokenUsage):
        """Split the LLM-reported usage into (prompt, completion) counts."""
        return token_usage["prompt_tokens"], token_usage["completion_tokens"]


class FunctionMessage(Message):
    """A message from a human, containing the result of a function call."""

    def __init__(self, name: str, content: str):
        super().__init__(role="function", content=content)
        self.name = name
        self._param_names = ["role", "name", "content"]


class AIMessageChunk(AIMessage):
    """A message chunk from an assistant."""
|
erniebot-agent/erniebot_agent/prompt/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .base import BasePromptTemplate
|
16 |
+
from .prompt_template import PromptTemplate
|
erniebot-agent/erniebot_agent/prompt/base.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from abc import ABC, abstractmethod
|
15 |
+
from typing import List, Optional
|
16 |
+
|
17 |
+
|
18 |
+
class BasePromptTemplate(ABC):
    """Abstract interface for prompt templates."""

    def __init__(self, input_variables: Optional[List[str]]):
        # Names of the variables the template expects at format time.
        self.input_variables: Optional[List[str]] = input_variables

    @abstractmethod
    def format(self, **kwargs):
        """Render the template to a string using `kwargs` as variables."""
        raise NotImplementedError

    @abstractmethod
    def format_as_message(self, message_class, **kwargs):
        """Render the template and wrap the result in the given message class."""
        raise NotImplementedError
|
erniebot-agent/erniebot_agent/prompt/prompt_template.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import Any, List, Optional
|
16 |
+
|
17 |
+
from erniebot_agent.messages import HumanMessage
|
18 |
+
from erniebot_agent.prompt import BasePromptTemplate
|
19 |
+
from jinja2 import Environment, meta
|
20 |
+
|
21 |
+
|
22 |
+
def jinja2_formatter(template: str, **kwargs: Any) -> str:
    """Render *template* with jinja2, substituting the given keyword arguments.

    Raises:
        ImportError: If jinja2 is not installed.
    """
    try:
        from jinja2 import Template
    except ImportError:
        raise ImportError(
            "jinja2 not installed, which is needed to use the jinja2_formatter. "
            "Please install it with `pip install jinja2`."
        )

    compiled = Template(template)
    return compiled.render(**kwargs)
|
33 |
+
|
34 |
+
|
35 |
+
class PromptTemplate(BasePromptTemplate):
    """Render a jinja2 template string into a prompt for LLM input.

    When ``input_variables`` is supplied, every call to :meth:`format` first
    checks that the declared variables exactly match the variables used in
    the template and raises ``KeyError`` on any mismatch.
    """

    def __init__(
        self, template: str, name: Optional[str] = None, input_variables: Optional[List[str]] = None
    ):
        """
        Args:
            template: A jinja2 template string.
            name: Optional name of this prompt template.
            input_variables: Variables the template is expected to use; when
                given, :meth:`format` validates them against the template.
        """
        super().__init__(input_variables)
        self.name = name
        self.template = template
        # Validation only makes sense when the caller declared the variables.
        # TODO: verify the template itself is well-formed.
        self.validate_template = input_variables is not None

    def format(self, **kwargs) -> str:
        """Render the template with ``kwargs`` and return the resulting string.

        Raises:
            KeyError: If validation is enabled and the declared input
                variables do not match the variables used in the template.
        """
        if self.validate_template:
            error = self._validate_template()
            if error:
                raise KeyError("The input_variables of PromptTemplate and template are not match! " + error)
        return jinja2_formatter(self.template, **kwargs)

    def _validate_template(self) -> str:
        """Compare ``self.input_variables`` with the variables the template uses.

        Returns:
            str: An empty string when they match, otherwise a description of
                the missing and/or extra variables.
        """
        declared_variables = set(self.input_variables)
        env = Environment()
        ast = env.parse(self.template)
        used_variables = meta.find_undeclared_variables(ast)

        missing_variables = used_variables - declared_variables
        extra_variables = declared_variables - used_variables

        error_message = ""
        if missing_variables:
            error_message += f"The missing input variables: {missing_variables} "

        if extra_variables:
            error_message += f"The extra input variables: {extra_variables}"

        return error_message

    # NOTE(review): the base class declares ``format_as_message(self,
    # message_class, **kwargs)`` but this override drops ``message_class``
    # and always returns a ``HumanMessage`` — confirm whether configurable
    # message classes are intended here.
    def format_as_message(self, **kwargs):
        """Render the template and wrap the result in a ``HumanMessage``."""
        prompt = self.format(**kwargs)
        return HumanMessage(content=prompt)
|
erniebot-agent/erniebot_agent/retrieval/__init__.py
ADDED
File without changes
|
erniebot-agent/erniebot_agent/retrieval/baizhong_search.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import json
|
3 |
+
from concurrent.futures import ThreadPoolExecutor
|
4 |
+
from typing import Any, Dict, List, Optional
|
5 |
+
|
6 |
+
import requests
|
7 |
+
from erniebot_agent.utils.exception import BaizhongError
|
8 |
+
from erniebot_agent.utils.logging import logger
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from .document import Document
|
12 |
+
|
13 |
+
|
14 |
+
class BaizhongSearch:
    """Client for the Baizhong search service.

    Wraps the Baizhong HTTP APIs for project/schema management, document
    indexing, deletion and retrieval. Either an existing ``project_id`` or
    a ``project_name`` (to create a new project plus schema) must be given.
    """

    def __init__(
        self,
        base_url: str,
        project_name: Optional[str] = None,
        remark: Optional[str] = None,
        project_id: Optional[int] = None,
        max_seq_length: int = 512,
    ) -> None:
        """
        Args:
            base_url: Root URL of the Baizhong service.
            project_name: Name for a new project (used when ``project_id`` is None).
            remark: Optional remark stored with a newly created project.
            project_id: ID of an existing project to attach to.
            max_seq_length: Paragraph size used in the index schema.

        Raises:
            BaizhongError: If neither ``project_name`` nor ``project_id`` is given,
                or if project/schema creation fails.
        """
        self.base_url = base_url
        self.max_seq_length = max_seq_length
        if project_id is not None:
            logger.info(f"Loading existing project with `project_id={project_id}`")
            self.project_id = project_id
        elif project_name is not None:
            logger.info("Creating new project and schema")
            self.index = self.create_project(project_name, remark)
            logger.info("Project creation succeeded")
            self.project_id = self.index["result"]["projectId"]
            self.create_schema()
            logger.info("Schema creation succeeded")
        else:
            raise BaizhongError("You must provide either a `project_name` or a `project_id`.")

    def _check_response(self, res: requests.Response) -> Dict[str, Any]:
        """Validate a Baizhong HTTP response and return its parsed JSON body.

        Raises:
            BaizhongError: On a non-200 status code or a non-zero ``errCode``.
        """
        if res.status_code != 200:
            raise BaizhongError(message=f"request error: {res.text}", error_code=res.status_code)
        result = res.json()
        if result["errCode"] != 0:
            raise BaizhongError(message=result["errMsg"], error_code=result["errCode"])
        return result

    def _schema_json(self) -> Dict[str, Any]:
        """Index-schema payload shared by ``create_schema`` and ``update_schema``."""
        return {
            "paraSize": self.max_seq_length,
            "dataSegmentationMod": "neisou",
            "storeType": "ElasticSearch",
            "properties": {
                "title": {"type": "text", "shortindex": True},
                "content_se": {"type": "text", "longindex": True},
            },
        }

    def create_project(self, project_name: str, remark: Optional[str] = None):
        """Create a project using the Baizhong API.

        Returns:
            dict: Information about the created project.

        Raises:
            BaizhongError: If the API request fails.
        """
        json_data = {
            "projectName": project_name,
            "remark": remark,
        }
        res = requests.post(f"{self.base_url}/baizhong/web-api/v2/project/add", json=json_data)
        return self._check_response(res)

    def create_schema(self):
        """Create the index schema for this project using the Baizhong API.

        Returns:
            dict: Information about the created schema.

        Raises:
            BaizhongError: If the API request fails.
        """
        json_data = {"projectId": self.project_id, "schemaJson": self._schema_json()}
        res = requests.post(f"{self.base_url}/baizhong/web-api/v2/project-schema/create", json=json_data)
        return self._check_response(res)

    def update_schema(
        self,
    ):
        """Update the index schema for this project using the Baizhong API.

        Returns:
            dict: Information about the updated schema.

        Raises:
            BaizhongError: If the API request fails.
        """
        json_data = {"projectId": self.project_id, "schemaJson": self._schema_json()}
        res = requests.post(f"{self.base_url}/baizhong/web-api/v2/project-schema/update", json=json_data)
        return self._check_response(res)

    def search(self, query: str, top_k: int = 10, filters: Optional[Dict[str, Any]] = None):
        """Perform a search using the Baizhong common search API.

        Args:
            query (str): The search query.
            top_k (int, optional): The number of top results to retrieve (default is 10).
            filters (Optional[Dict[str, Any]], optional): Additional filters to apply
                to the search query (default is None).

        Returns:
            List[Dict[str, Any]]: A list of dictionaries containing search results.

        Raises:
            BaizhongError: If the API request fails.
        """
        json_data: Dict[str, Any] = {
            "query": query,
            "projectId": self.project_id,
            "size": top_k,
        }
        if filters is not None:
            json_data["filterConditions"] = {"bool": {"filter": {"match": filters}}}
        res = requests.post(f"{self.base_url}/baizhong/common-search/v2/search", json=json_data)
        result = self._check_response(res)
        list_data = []
        for item in result["hits"]:
            # Each hit's payload is stored as base64-encoded JSON.
            content = base64.b64decode(item["_source"]["doc"]).decode("utf-8")
            list_data.append(json.loads(content))
        return list_data

    def add_documents(self, documents: List[Document], batch_size: int = 1, thread_count: int = 1):
        """Add a batch of documents to the Baizhong system using multi-threading.

        Args:
            documents (List[Document]): Documents (or pre-serialized dicts) to add.
            batch_size (int, optional): The size of each batch of documents (defaults to 1).
            thread_count (int, optional): Number of threads used for concurrent
                addition (defaults to 1).

        Returns:
            List[Union[None, Exception]]: Results of the per-batch addition calls.
        """
        # Normalize to dicts up front. The previous version referenced
        # ``list_dicts`` even when the input was already dicts, which raised
        # NameError; isinstance is also preferred over ``type(...) ==``.
        if documents and isinstance(documents[0], Document):
            list_dicts = [item.to_dict() for item in documents]
        else:
            list_dicts = documents
        all_data = []
        for i in tqdm(range(0, len(list_dicts), batch_size)):
            all_data.append(list_dicts[i : i + batch_size])
        with ThreadPoolExecutor(max_workers=thread_count) as executor:
            res = executor.map(self._add_documents, all_data)
        return list(res)

    def get_document_by_id(self, doc_id):
        """Retrieve a document from the Baizhong system by its ID.

        Args:
            doc_id: The ID of the document to retrieve.

        Returns:
            dict: Information about the retrieved document.

        Raises:
            BaizhongError: If the API request fails.
        """
        json_data = {"projectId": self.project_id, "followIndexFlag": True, "dataBody": [doc_id]}
        res = requests.post(f"{self.base_url}/baizhong/data-api/v2/flush/get", json=json_data)
        return self._check_response(res)

    def delete_documents(
        self,
        ids: Optional[List[str]] = None,
    ):
        """Delete documents from the Baizhong system.

        Args:
            ids (Optional[List[str]], optional): Document IDs to delete. Deleting
                all documents (``ids=None``) is not implemented yet.

        Returns:
            dict: Information about the deletion process.

        Raises:
            NotImplementedError: If deletion of all documents is attempted.
            BaizhongError: If the API request fails.
        """
        json_data: Dict[str, Any] = {"projectId": self.project_id, "followIndexFlag": True}
        if ids is not None:
            json_data["dataBody"] = ids
        else:
            # TODO: delete all documents
            raise NotImplementedError
        res = requests.post(f"{self.base_url}/baizhong/data-api/v2/flush/delete", json=json_data)
        return self._check_response(res)

    def _add_documents(self, documents: List[Dict[str, Any]]):
        """Internal method to add one batch of documents.

        Args:
            documents (List[Dict[str, Any]]): Serialized documents to add.

        Returns:
            dict: Information about the document addition process.

        Raises:
            BaizhongError: If the API request fails.
        """
        # TODO(wugaosheng): retry 3 times
        json_data = {"projectId": self.project_id, "followIndexFlag": True, "dataBody": documents}
        res = requests.post(f"{self.base_url}/baizhong/data-api/v2/flush/add", json=json_data)
        return self._check_response(res)

    def delete_project(self, project_id: int):
        """Delete a project using the Baizhong API.

        NOTE: this was previously a ``classmethod`` that read ``cls.base_url``,
        which always raised ``AttributeError`` because ``base_url`` is an
        instance attribute; it is now an instance method (calls on an
        instance keep working unchanged).

        Args:
            project_id (int): The ID of the project to be deleted.

        Returns:
            dict: Information about the deletion process.

        Raises:
            BaizhongError: If the API request fails.
        """
        json_data = {"projectId": project_id}
        res = requests.post(f"{self.base_url}/baizhong/web-api/v2/project/delete", json=json_data)
        return self._check_response(res)
|
erniebot-agent/erniebot_agent/retrieval/document.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import json
|
3 |
+
from typing import Any, Dict, Optional, Union
|
4 |
+
|
5 |
+
from pydantic import BaseConfig
|
6 |
+
from pydantic.dataclasses import dataclass
|
7 |
+
|
8 |
+
BaseConfig.arbitrary_types_allowed = True
|
9 |
+
|
10 |
+
|
11 |
+
@dataclass
class Document:
    """A retrievable document: an id, a title, searchable content, and metadata."""

    id: str
    title: str
    content_se: str
    meta: Dict[str, Any]

    def __init__(
        self,
        content_se: str,
        title: str,
        id: Optional[str] = None,
        meta: Optional[Dict[str, Any]] = None,
    ):
        self.content_se = content_se
        self.title = title
        # Derive a stable, content-based id when none is supplied. The
        # previous code called ``self._get_id()`` without an argument, which
        # crashed on ``None.encode`` inside ``_get_id``.
        self.id = id or self._get_id(content_se)
        self.meta = meta or {}

    @classmethod
    def _get_id(cls, content_se: str) -> str:
        """Return the hex MD5 digest of ``content_se``, used as a content-derived id."""
        md5_bytes = content_se.encode(encoding="UTF-8")
        md5_string = hashlib.md5(md5_bytes).hexdigest()
        return md5_string

    def to_dict(self, field_map: Optional[Dict[str, Any]] = None) -> Dict:
        """
        Convert Document to dict. An optional field_map can be supplied to
        change the names of the keys in the resulting dict.
        This way you can work with standardized Document objects in erniebot-agent,
        but adjust the format that they are serialized / stored in other places
        (e.g. elasticsearch)
        Example:

        ```python
        doc = Document(content="some text", content_type="text")
        doc.to_dict(field_map={"custom_content_field": "content"})

        # Returns {"custom_content_field": "some text"}
        ```

        :param field_map: Dict with keys being the custom target keys and values
            being the standard Document attributes
        :return: dict with content of the Document
        """
        if not field_map:
            field_map = {}

        # Invert once so each attribute lookup is O(1).
        inv_field_map = {v: k for k, v in field_map.items()}
        _doc: Dict[str, str] = {}
        for k, v in self.__dict__.items():
            # Exclude internal fields (Pydantic, ...) fields from the conversion process
            if k.startswith("__"):
                continue
            k = k if k not in inv_field_map else inv_field_map[k]
            _doc[k] = v
        return _doc

    @classmethod
    def from_dict(cls, dict: Dict[str, Any], field_map: Optional[Dict[str, Any]] = None):
        """
        Create Document from dict. An optional `field_map` parameter
        can be supplied to adjust for custom names of the keys in the
        input dict. This way you can work with standardized Document
        objects in erniebot-agent, but adjust the format that
        they are serialized / stored in other places (e.g. elasticsearch).

        Example:

        ```python
        my_dict = {"custom_content_field": "some text", "content_type": "text"}
        Document.from_dict(my_dict, field_map={"custom_content_field": "content"})
        ```

        :param field_map: Dict with keys being the custom target keys and values
            being the standard Document attributes
        :return: A Document object
        """
        # NOTE(review): the parameter name ``dict`` shadows the builtin; it is
        # kept to avoid breaking keyword callers.
        if not field_map:
            field_map = {}

        _doc = dict.copy()
        init_args = ["content_se", "meta", "id", "title"]
        if "meta" not in _doc.keys():
            _doc["meta"] = {}
        if "id" not in _doc.keys():
            _doc["id"] = cls._get_id(_doc["content_se"])
        # copy additional fields into "meta"
        for k, v in _doc.items():
            # Exclude internal fields (Pydantic, ...) fields from the conversion process
            if k.startswith("__"):
                continue
            if k not in init_args and k not in field_map:
                _doc["meta"][k] = v
        # remove additional fields from top level
        _new_doc = {}
        for k, v in _doc.items():
            if k in init_args:
                _new_doc[k] = v
            elif k in field_map:
                k = field_map[k]
                _new_doc[k] = v
        return cls(**_new_doc)

    @classmethod
    def from_json(cls, data: Union[str, Dict[str, Any]], field_map: Optional[Dict[str, Any]] = None):
        """Create a Document from a JSON string or an already-parsed dict."""
        if not field_map:
            field_map = {}
        if isinstance(data, str):
            dict_data = json.loads(data)
        else:
            dict_data = data
        return cls.from_dict(dict_data, field_map=field_map)
|
erniebot-agent/erniebot_agent/tools/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .image_generation_tool import ImageGenerationTool
|
erniebot-agent/erniebot_agent/tools/baizhong_tool.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import Any, List, Optional, Type
|
4 |
+
|
5 |
+
from erniebot_agent.messages import AIMessage, HumanMessage
|
6 |
+
from erniebot_agent.tools.schema import ToolParameterView
|
7 |
+
from pydantic import Field
|
8 |
+
|
9 |
+
from .base import Tool
|
10 |
+
|
11 |
+
|
12 |
+
class BaizhongSearchToolInputView(ToolParameterView):
    """Input schema for ``BaizhongSearchTool``: the query string and result count."""

    query: str = Field(description="Query")
    top_k: int = Field(description="Number of results to return")
|
15 |
+
|
16 |
+
|
17 |
+
class SearchResponseDocument(ToolParameterView):
    """One search hit: document id, title, and content."""

    id: str = Field(description="text id")
    title: str = Field(description="title")
    document: str = Field(description="content")
|
21 |
+
|
22 |
+
|
23 |
+
class BaizhongSearchToolOutputView(ToolParameterView):
    """Output schema for ``BaizhongSearchTool``: the list of search hits."""

    documents: List[SearchResponseDocument] = Field(description="research results")
|
25 |
+
|
26 |
+
|
27 |
+
class BaizhongSearchTool(Tool):
    """Tool that answers queries by searching a Baizhong database."""

    description: str = "aurora search tool"
    input_type: Type[ToolParameterView] = BaizhongSearchToolInputView
    # NOTE: "ouptut_type" (sic) matches the attribute name declared by the
    # ``Tool`` base class; renaming it only here would break that contract.
    ouptut_type: Type[ToolParameterView] = BaizhongSearchToolOutputView

    def __init__(self, description, db, input_type=None, output_type=None, examples=None) -> None:
        """
        Args:
            description: Human-readable description exposed to the LLM.
            db: Search backend; must provide ``search(query, top_k, filters)``.
            input_type: Optional override of the input schema.
            output_type: Optional override of the output schema.
            examples: Optional few-shot examples, each a dict with
                "user", "thoughts" and "arguments" keys.
        """
        super().__init__()
        self.db = db
        self.description = description
        if input_type is not None:
            self.input_type = input_type
        if output_type is not None:
            self.ouptut_type = output_type
        # Always initialize, so that the ``examples`` property does not raise
        # AttributeError when no few-shot examples were provided (previously
        # the attribute was only set inside the ``if``).
        self.few_shot_examples = examples if examples is not None else []

    async def __call__(self, query: str, top_k: int = 10, filters: Optional[dict[str, Any]] = None):
        """Search the backing database and return its raw result list."""
        res = self.db.search(query, top_k, filters)
        return res

    @property
    def examples(
        self,
    ) -> List[Any]:
        """Few-shot examples as alternating Human/AI messages for function calling."""
        few_shot_objects: List[Any] = []
        for item in self.few_shot_examples:
            few_shot_objects.append(HumanMessage(item["user"]))
            few_shot_objects.append(
                AIMessage(
                    "",
                    function_call={
                        "name": self.tool_name,
                        "thoughts": item["thoughts"],
                        "arguments": item["arguments"],
                    },
                )
            )

        return few_shot_objects
|
erniebot-agent/erniebot_agent/tools/base.py
ADDED
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import annotations
|
16 |
+
|
17 |
+
import json
|
18 |
+
import os
|
19 |
+
import tempfile
|
20 |
+
from abc import ABC, abstractmethod
|
21 |
+
from dataclasses import asdict, dataclass, field
|
22 |
+
from typing import Any, Dict, List, Optional, Type
|
23 |
+
|
24 |
+
import requests
|
25 |
+
from erniebot_agent.messages import AIMessage, FunctionCall, HumanMessage, Message
|
26 |
+
from erniebot_agent.tools.schema import (
|
27 |
+
Endpoint,
|
28 |
+
EndpointInfo,
|
29 |
+
RemoteToolView,
|
30 |
+
ToolParameterView,
|
31 |
+
scrub_dict,
|
32 |
+
)
|
33 |
+
from erniebot_agent.utils.http import url_file_exists
|
34 |
+
from erniebot_agent.utils.logging import logger
|
35 |
+
from openapi_spec_validator import validate
|
36 |
+
from openapi_spec_validator.readers import read_from_filename
|
37 |
+
from yaml import safe_dump
|
38 |
+
|
39 |
+
import erniebot
|
40 |
+
|
41 |
+
|
42 |
+
def validate_openapi_yaml(yaml_file: str) -> bool:
    """Check whether *yaml_file* contains a valid OpenAPI specification.

    Args:
        yaml_file (str): Path to the YAML file to check.

    Returns:
        bool: True when the specification validates; False otherwise
            (the validation error is logged).
    """
    spec = read_from_filename(yaml_file)[0]
    try:
        validate(spec)
    except Exception as err:  # type: ignore
        logger.error(err)
        return False
    else:
        return True
|
58 |
+
|
59 |
+
|
60 |
+
class BaseTool(ABC):
    """Minimal interface shared by local tools and remote (HTTP-backed) tools."""

    @abstractmethod
    async def __call__(self, *args: Any, **kwds: Any) -> Any:
        """Execute the tool with the given arguments and return its result."""
        raise NotImplementedError

    @abstractmethod
    def function_call_schema(self) -> dict:
        """Return the schema describing this tool for LLM function calling."""
        raise NotImplementedError
|
68 |
+
|
69 |
+
|
70 |
+
class Tool(BaseTool, ABC):
    """Base class for tools implemented locally in Python.

    Subclasses set ``description`` (and optionally ``name``, ``input_type``,
    ``ouptut_type``) and implement the async ``__call__`` body.
    """

    description: str
    name: Optional[str] = None
    input_type: Optional[Type[ToolParameterView]] = None
    # NOTE(review): "ouptut_type" (sic) — the misspelling is part of the
    # public attribute name that subclasses rely on, so it is kept.
    ouptut_type: Optional[Type[ToolParameterView]] = None

    def __str__(self) -> str:
        return f"<name: {self.name}, description: {self.description}>"

    def __repr__(self):
        return str(self)

    @property
    def tool_name(self):
        # Fall back to the class name when no explicit name was set.
        return self.name or self.__class__.__name__

    @abstractmethod
    async def __call__(self, *args: Any, **kwds: Any) -> Dict[str, Any]:
        """Run the tool; subclasses implement the actual behavior.

        Returns:
            Any:
        """
        raise NotImplementedError

    def function_call_schema(self) -> dict:
        """Build the function-calling schema: name, description, examples and I/O views."""
        inputs: Dict[str, Any] = {
            "name": self.tool_name,
            "description": self.description,
            "examples": [example.to_dict() for example in self.examples],
        }
        if self.input_type is not None:
            inputs["parameters"] = self.input_type.function_call_schema()
        if self.ouptut_type is not None:
            inputs["responses"] = self.ouptut_type.function_call_schema()

        return scrub_dict(inputs) or {}

    @property
    def examples(self) -> List[Message]:
        """Few-shot examples; empty by default, overridden by subclasses."""
        return []
|
111 |
+
|
112 |
+
|
113 |
+
class RemoteTool(BaseTool):
    """A tool backed by a remote HTTP endpoint described by a ``RemoteToolView``."""

    def __init__(
        self,
        tool_view: RemoteToolView,
        server_url: str,
        headers: dict,
        examples: Optional[List[Message]] = None,
    ) -> None:
        self.tool_view = tool_view
        self.server_url = server_url
        self.headers = headers
        self.examples = examples

    def __str__(self) -> str:
        return f"<name: {self.tool_name}, server_url: {self.server_url}, description: {self.tool_view.description}>"

    def __repr__(self):
        return str(self)

    @property
    def tool_name(self):
        return self.tool_view.name

    async def __call__(self, **tool_arguments: Dict[str, Any]) -> Any:
        """Send ``tool_arguments`` to the remote endpoint and return its JSON response.

        Raises:
            ValueError: On an unknown HTTP verb or a non-200 response.
        """
        url = self.server_url + self.tool_view.uri
        method = self.tool_view.method

        if method == "get":
            # GET carries the arguments as query parameters.
            response = requests.get(url, params=tool_arguments, headers=self.headers)
        elif method in ("post", "put", "delete"):
            # The remaining verbs carry the arguments as a JSON body.
            response = getattr(requests, method)(url, json=tool_arguments, headers=self.headers)
        else:
            raise ValueError(f"method<{method}> is invalid")

        if response.status_code != 200:
            raise ValueError(f"the resource is invalid, the error message is: {response.text}")

        return response.json()

    def function_call_schema(self) -> dict:
        """Schema from the underlying view, augmented with examples when present."""
        schema = self.tool_view.function_call_schema()
        if self.examples is not None:
            schema["examples"] = [example.to_dict() for example in self.examples]

        return schema or {}
|
163 |
+
|
164 |
+
|
165 |
+
@dataclass
class RemoteToolkit:
    """A collection of RemoteTools parsed from an OpenAPI description.

    A toolkit can be built from an openapi.yaml file, an already-parsed spec
    dict, or a remote server exposing ``/.well-known/openapi.yaml``.
    """

    openapi: str  # OpenAPI spec version string, e.g. "3.0.1"
    info: EndpointInfo  # the spec's `info` section
    servers: List[Endpoint]  # declared servers; servers[0] is used for calls
    paths: List[RemoteToolView]  # one view per (uri, method) operation

    component_schemas: dict[str, Type[ToolParameterView]]  # parsed components/schemas
    headers: dict  # HTTP headers (incl. auth) attached to every tool call
    examples: List[Message] = field(default_factory=list)  # few-shot messages

    def __getitem__(self, tool_name: str):
        # Convenience alias: toolkit["tool"] == toolkit.get_tool("tool")
        return self.get_tool(tool_name)

    def get_tools(self) -> List[RemoteTool]:
        """Wrap every operation in a RemoteTool bound to the first server."""
        return [
            RemoteTool(
                path, self.servers[0].url, self.headers, examples=self.get_examples_by_name(path.name)
            )
            for path in self.paths
        ]

    def get_examples_by_name(self, tool_name: str) -> List[Message]:
        """get examples by tool-name

        Args:
            tool_name (str): the name of the tool

        Returns:
            List[Message]: the messages belonging to conversations that
            invoke *tool_name*
        """
        # 1. split messages into conversations, each starting at a HumanMessage
        tool_examples: List[List[Message]] = []
        examples: List[Message] = []
        for example in self.examples:
            if isinstance(example, HumanMessage):
                if len(examples) == 0:
                    examples.append(example)
                else:
                    tool_examples.append(examples)
                    examples = [example]
            else:
                examples.append(example)

        # Flush the trailing conversation, if any.
        if len(examples) > 0:
            tool_examples.append(examples)

        # NOTE(review): "final_exampels" is misspelled but local-only.
        final_exampels: List[Message] = []
        # 2. find the target tool examples or empty messages
        for examples in tool_examples:
            # Names of every tool invoked by an AI function call in this conversation.
            tool_names = [
                example.function_call.get("name", None)
                for example in examples
                if isinstance(example, AIMessage) and example.function_call is not None
            ]
            tool_names = [name for name in tool_names if name]

            if tool_name in tool_names:
                final_exampels.extend(examples)

        return final_exampels

    def get_tool(self, tool_name: str) -> RemoteTool:
        """Return the single RemoteTool whose operation name is *tool_name*."""
        paths = [path for path in self.paths if path.name == tool_name]
        # NOTE(review): assert is stripped under `python -O`; an explicit
        # raise would be safer — confirm before changing.
        assert len(paths) == 1, f"tool<{tool_name}> not found in paths"
        return RemoteTool(
            paths[0], self.servers[0].url, self.headers, examples=self.get_examples_by_name(tool_name)
        )

    def to_openapi_dict(self) -> dict:
        """convert plugin schema to openapi spec dict"""
        spec_dict = {
            "openapi": self.openapi,
            "info": asdict(self.info),
            "servers": [asdict(server) for server in self.servers],
            "paths": {tool_view.uri: tool_view.to_openapi_dict() for tool_view in self.paths},
            "components": {
                "schemas": {
                    uri: parameters_view.to_openapi_dict()
                    for uri, parameters_view in self.component_schemas.items()
                }
            },
        }
        # Drop empty nodes so the emitted spec stays minimal.
        return scrub_dict(spec_dict, remove_empty_dict=True) or {}

    def to_openapi_file(self, file: str):
        """generate openapi configuration file

        Args:
            file (str): the path of the openapi yaml file
        """
        spec_dict = self.to_openapi_dict()
        with open(file, "w+", encoding="utf-8") as f:
            safe_dump(spec_dict, f, indent=4)

    @classmethod
    def from_openapi_dict(
        cls, openapi_dict: Dict[str, Any], access_token: Optional[str] = None
    ) -> RemoteToolkit:
        """Build a toolkit from an already-parsed OpenAPI spec dict."""
        info = EndpointInfo(**openapi_dict["info"])
        servers = [Endpoint(**server) for server in openapi_dict.get("servers", [])]

        # components: parse each schema into a ToolParameterView subclass
        component_schemas = openapi_dict["components"]["schemas"]
        fields = {}
        for schema_name, schema in component_schemas.items():
            parameter_view = ToolParameterView.from_openapi_dict(schema_name, schema)
            fields[schema_name] = parameter_view

        # paths: one RemoteToolView per (uri, method) operation
        paths = []
        for path, path_info in openapi_dict.get("paths", {}).items():
            for method, path_method_info in path_info.items():
                paths.append(
                    RemoteToolView.from_openapi_dict(
                        uri=path,
                        method=method,
                        path_info=path_method_info,
                        parameters_views=fields,
                    )
                )

        return RemoteToolkit(
            openapi=openapi_dict["openapi"],
            info=info,
            servers=servers,
            paths=paths,
            component_schemas=fields,
            headers=cls._get_authorization_headers(access_token),
        )  # type: ignore

    @classmethod
    def from_openapi_file(cls, file: str, access_token: Optional[str] = None) -> RemoteToolkit:
        """only support openapi v3.0.1

        Args:
            file (str): the path of openapi yaml file
            access_token (Optional[str]): the token used to authorize remote calls
        """
        if not validate_openapi_yaml(file):
            raise ValueError(f"invalid openapi yaml file: {file}")

        spec_dict, _ = read_from_filename(file)
        return cls.from_openapi_dict(spec_dict, access_token=access_token)

    @classmethod
    def _get_authorization_headers(cls, access_token: Optional[str]) -> dict:
        """Build the default request headers, adding auth when a token exists."""
        if access_token is None:
            # Fall back to the globally-configured erniebot token.
            access_token = erniebot.access_token

        headers = {"Content-Type": "application/json"}
        if access_token is None:
            logger.warning("access_token is NOT provided, this may cause 403 HTTP error..")
        else:
            headers["Authorization"] = f"token {access_token}"
        return headers

    @classmethod
    def from_url(cls, url: str, access_token: Optional[str] = None) -> RemoteToolkit:
        """Fetch ``<url>/.well-known/openapi.yaml`` and build a toolkit from it."""
        # 1. download openapi.yaml file to temp directory
        if not url.endswith("/"):
            url += "/"
        openapi_yaml_url = url + ".well-known/openapi.yaml"

        with tempfile.TemporaryDirectory() as temp_dir:
            response = requests.get(openapi_yaml_url, headers=cls._get_authorization_headers(access_token))
            if response.status_code != 200:
                raise ValueError(f"the resource is invalid, the error message is: {response.text}")

            file_content = response.content.decode("utf-8")
            if not file_content.strip():
                raise ValueError(f"the content is empty from: {openapi_yaml_url}")

            file_path = os.path.join(temp_dir, "openapi.yaml")
            with open(file_path, "w+", encoding="utf-8") as f:
                f.write(file_content)

            toolkit = RemoteToolkit.from_openapi_file(file_path, access_token=access_token)
            # Rebind every declared server to the URL we actually fetched from.
            for server in toolkit.servers:
                server.url = url

            toolkit.examples = cls.load_remote_examples_yaml(url, access_token)

            return toolkit

    @classmethod
    def load_remote_examples_yaml(cls, url: str, access_token: Optional[str] = None) -> List[Message]:
        """load remote examples by url: url/.well-known/examples.yaml

        Args:
            url (str): the base url of the remote toolkit
        """
        if not url.endswith("/"):
            url += "/"
        examples_yaml_url = url + ".well-known/examples.yaml"
        # The examples file is optional; a missing file just means no few-shots.
        if not url_file_exists(examples_yaml_url, cls._get_authorization_headers(access_token)):
            return []

        examples = []
        with tempfile.TemporaryDirectory() as temp_dir:
            response = requests.get(examples_yaml_url, headers=cls._get_authorization_headers(access_token))
            if response.status_code != 200:
                raise ValueError(
                    f"Invalid resource, status_code: {response.status_code}, error message: {response.text}"
                )

            file_content = response.content.decode("utf-8")
            if not file_content.strip():
                raise ValueError(f"the content is empty from: {examples_yaml_url}")

            file_path = os.path.join(temp_dir, "examples.yaml")
            with open(file_path, "w+", encoding="utf-8") as f:
                f.write(file_content)

            examples = cls.load_examples_yaml(file_path)

        return examples

    @classmethod
    def load_examples_dict(cls, examples_dict: Dict[str, Any]) -> List[Message]:
        """Convert an examples dict ({"examples": [{"context": [...]}]})
        into alternating Human/AI messages."""
        messages: List[Message] = []
        for examples in examples_dict["examples"]:
            examples = examples["context"]
            for example in examples:
                if "user" == example["role"]:
                    messages.append(HumanMessage(example["content"]))
                # NOTE(review): substring match — a role like "robot" would
                # also pass; confirm exact-match was not intended.
                elif "bot" in example["role"]:
                    plugin = example["plugin"]
                    if "operationId" in plugin:
                        function_call: FunctionCall = {
                            "name": plugin["operationId"],
                            "thoughts": plugin["thoughts"],
                            "arguments": json.dumps(plugin["requestArguments"], ensure_ascii=False),
                        }
                    else:
                        # No operation id: keep the thoughts but call nothing.
                        function_call = {
                            "name": "",
                            "thoughts": plugin["thoughts"],
                            "arguments": "{}",
                        }  # type: ignore
                    messages.append(AIMessage("", function_call=function_call))
                else:
                    raise ValueError(f"invald role: <{example['role']}>")
        return messages

    @classmethod
    def load_examples_yaml(cls, file: str) -> List[Message]:
        """load examples from yaml file

        Args:
            file (str): the path of examples file

        Returns:
            List[Message]: the list of messages
        """
        content: dict = read_from_filename(file)[0]
        if len(content) == 0 or "examples" not in content:
            raise ValueError("invalid examples configuration file")
        return cls.load_examples_dict(content)

    def function_call_schemas(self) -> List[dict]:
        """Function-calling schemas for every tool in the toolkit."""
        return [tool.function_call_schema() for tool in self.get_tools()]
|
erniebot-agent/erniebot_agent/tools/calculator_tool.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations

import ast
from typing import Dict, List, Type

from erniebot_agent.messages import AIMessage, HumanMessage, Message
from erniebot_agent.tools.schema import ToolParameterView
from pydantic import Field

from .base import Tool
|
10 |
+
|
11 |
+
|
12 |
+
class CalculatorToolInputView(ToolParameterView):
    """Input schema for CalculatorTool."""

    # Formula text supplied by the model; the Field description is surfaced
    # to the LLM through the function-calling schema.
    math_formula: str = Field(description='标准的数学公式,例如:"2+3"、"3 - 4 * 6", "(3 + 4) * (6 + 4)" 等。 ')
|
14 |
+
|
15 |
+
|
16 |
+
class CalculatorToolOutputView(ToolParameterView):
    """Output schema for CalculatorTool."""

    # Numeric result of evaluating the formula.
    formula_result: float = Field(description="数学公式计算的结果")
|
18 |
+
|
19 |
+
|
20 |
+
class CalculatorTool(Tool):
    """Evaluate arithmetic formulas produced by the model.

    Security fix: the original implementation passed the model-generated
    formula straight to ``eval``, which allows arbitrary code execution.
    The formula is now interpreted by walking a restricted arithmetic AST;
    anything that is not plain arithmetic raises ValueError.
    """

    description: str = "CalculatorTool用于执行数学公式计算"
    input_type: Type[ToolParameterView] = CalculatorToolInputView
    ouptut_type: Type[ToolParameterView] = CalculatorToolOutputView

    async def __call__(self, math_formula: str) -> Dict[str, float]:
        """Compute *math_formula* and return the value under "formula_result".

        Raises:
            ValueError: if the formula is not a pure arithmetic expression.
        """
        return {"formula_result": self._safe_eval(math_formula)}

    @staticmethod
    def _safe_eval(formula: str) -> float:
        """Evaluate an arithmetic expression without using ``eval``.

        Supports numeric literals, parentheses, unary +/-, and the binary
        operators + - * / // % **.  Anything else raises ValueError.
        """
        binary_ops = {
            ast.Add: lambda a, b: a + b,
            ast.Sub: lambda a, b: a - b,
            ast.Mult: lambda a, b: a * b,
            ast.Div: lambda a, b: a / b,
            ast.FloorDiv: lambda a, b: a // b,
            ast.Mod: lambda a, b: a % b,
            ast.Pow: lambda a, b: a**b,
        }

        def evaluate(node):
            if isinstance(node, ast.Expression):
                return evaluate(node.body)
            if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
                return node.value
            if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
                operand = evaluate(node.operand)
                return -operand if isinstance(node.op, ast.USub) else +operand
            if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
                return binary_ops[type(node.op)](evaluate(node.left), evaluate(node.right))
            raise ValueError(f"unsupported expression in math formula: {ast.dump(node)}")

        try:
            tree = ast.parse(formula, mode="eval")
        except SyntaxError as exc:
            raise ValueError(f"invalid math formula: {formula!r}") from exc
        return evaluate(tree)

    @property
    def examples(self) -> List[Message]:
        """Few-shot conversations demonstrating the expected function calls."""
        return [
            HumanMessage("请告诉我三加六等于多少?"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": f"用户想知道3加6等于多少,我可以使用{self.tool_name}工具来计算公式,其中`math_formula`字段的内容为:'3+6'。",
                    "arguments": '{"math_formula": "3+6"}',
                },
                token_usage={
                    "prompt_tokens": 5,
                    "completion_tokens": 7,
                },  # TODO: Functional AIMessage will not add in the memory, will it add token_usage?
            ),
            HumanMessage("一加八再乘以5是多少?"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": f"用户想知道1加8再乘5等于多少,我可以使用{self.tool_name}工具来计算公式,"
                    "其中`math_formula`字段的内容为:'(1+8)*5'。",
                    "arguments": '{"math_formula": "(1+8)*5"}',
                },
                token_usage={"prompt_tokens": 5, "completion_tokens": 7},  # For test only
            ),
            HumanMessage("我想知道十二除以四再加五等于多少?"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": f"用户想知道12除以4再加5等于多少,我可以使用{self.tool_name}工具来计算公式,"
                    "其中`math_formula`字段的内容为:'12/4+5'。",
                    "arguments": '{"math_formula": "12/4+5"}',
                },
                token_usage={"prompt_tokens": 5, "completion_tokens": 7},  # For test only
            ),
        ]
|
erniebot-agent/erniebot_agent/tools/current_time_tool.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from datetime import datetime
|
4 |
+
from typing import Dict, List, Type
|
5 |
+
|
6 |
+
from erniebot_agent.messages import AIMessage, HumanMessage, Message
|
7 |
+
from erniebot_agent.tools.schema import ToolParameterView
|
8 |
+
from pydantic import Field
|
9 |
+
|
10 |
+
from .base import Tool
|
11 |
+
|
12 |
+
|
13 |
+
class CurrentTimeToolOutputView(ToolParameterView):
    """Output schema for CurrentTimeTool."""

    # Human-readable local-time string.
    current_time: str = Field(description="当前时间")
|
15 |
+
|
16 |
+
|
17 |
+
class CurrentTimeTool(Tool):
    """Tool that reports the current local time as a formatted string."""

    description: str = "CurrentTimeTool 用于获取当前时间"
    ouptut_type: Type[ToolParameterView] = CurrentTimeToolOutputView

    async def __call__(self) -> Dict[str, str]:
        """Return the current local time under the "current_time" key.

        Fix: the original format string used bare "%点:%分:%秒", which are
        not valid strftime directives, so hour/minute/second were never
        substituted.  The directives are now "%H点:%M分:%S秒".
        """
        return {"current_time": datetime.now().strftime("%Y年%m月%d号 %H点:%M分:%S秒")}

    @property
    def examples(self) -> List[Message]:
        """Few-shot conversations demonstrating the expected function calls."""
        return [
            HumanMessage("现在几点钟了"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": f"用户想知道现在几点了,我可以使用{self.tool_name}来获取当前时间,并从其中获得当前小时时间。",
                    "arguments": "{}",
                },
                token_usage={"prompt_tokens": 5, "completion_tokens": 7},  # For test only
            ),
            HumanMessage("现在是什么时候?"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": f"用户想知道现在几点了,我可以使用{self.tool_name}来获取当前时间",
                    "arguments": "{}",
                },
                token_usage={"prompt_tokens": 5, "completion_tokens": 7},  # For test only
            ),
        ]
|
erniebot-agent/erniebot_agent/tools/image_generation_tool.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import annotations
|
16 |
+
|
17 |
+
import os
|
18 |
+
import uuid
|
19 |
+
from typing import Any, Dict, List, Optional, Type
|
20 |
+
|
21 |
+
from erniebot_agent.messages import AIMessage, HumanMessage, Message
|
22 |
+
from erniebot_agent.tools.base import Tool
|
23 |
+
from erniebot_agent.tools.schema import ToolParameterView
|
24 |
+
from erniebot_agent.utils.common import download_file, get_cache_dir
|
25 |
+
from pydantic import Field
|
26 |
+
|
27 |
+
import erniebot
|
28 |
+
|
29 |
+
|
30 |
+
class ImageGenerationInputView(ToolParameterView):
    """Input schema for ImageGenerationTool."""

    # Field descriptions are surfaced to the LLM via the function-call schema.
    prompt: str = Field(description="描述图像内容、风格的文本。例如:生成一张月亮的照片,月亮很圆。")
    width: int = Field(description="生成图片的宽度")
    height: int = Field(description="生成图片的高度")
    image_num: int = Field(description="生成图片的数量")
|
35 |
+
|
36 |
+
|
37 |
+
class ImageGenerationOutputView(ToolParameterView):
    """Output schema for ImageGenerationTool."""

    # NOTE(review): the tool's __call__ actually returns a *list* of paths
    # under this key while the schema declares a single str — confirm which
    # shape is intended.
    image_path: str = Field(description="图片在本地机器上的保存路径")
|
39 |
+
|
40 |
+
|
41 |
+
class ImageGenerationTool(Tool):
    """Generate images from a text prompt via the ERNIE-ViLG (yinian) API.

    Generated images are downloaded into the local cache directory and the
    tool returns their local file paths.
    """

    description: str = "AI作图、生成图片、画图的工具"
    input_type: Type[ToolParameterView] = ImageGenerationInputView
    ouptut_type: Type[ToolParameterView] = ImageGenerationOutputView

    def __init__(
        self,
        yinian_access_token: Optional[str] = None,
        yinian_ak: Optional[str] = None,
        yinian_sk: Optional[str] = None,
    ) -> None:
        """Configure yinian credentials: either an access token, or an AK/SK pair.

        Raises:
            ValueError: if neither form of credentials is provided.
        """
        self.config: Dict[str, Optional[Any]]
        if yinian_access_token is not None:
            self.config = {"api_type": "yinian", "access_token": yinian_access_token}
        elif yinian_ak is not None and yinian_sk is not None:
            self.config = {"api_type": "yinian", "ak": yinian_ak, "sk": yinian_sk}
        else:
            raise ValueError("Please set the yinian_access_token, or set yinian_ak and yinian_sk")

    async def __call__(
        self,
        prompt: str,
        width: int = 512,
        height: int = 512,
        image_num: int = 1,
    ) -> Dict[str, List[str]]:
        """Generate *image_num* images for *prompt* and return their local paths.

        NOTE(review): the response is assumed to contain
        data.sub_task_result_list[*].final_image_list[0].img_url — confirm
        against the erniebot Image API contract.
        """
        response = erniebot.Image.create(
            model="ernie-vilg-v2",
            prompt=prompt,
            width=width,
            height=height,
            image_num=image_num,
            _config_=self.config,
        )

        image_path = []
        cache_dir = get_cache_dir()
        for item in response["data"]["sub_task_result_list"]:
            image_url = item["final_image_list"][0]["img_url"]
            # uuid1-based filename avoids collisions within the cache dir.
            save_path = os.path.join(cache_dir, f"img_{uuid.uuid1()}.png")
            download_file(image_url, save_path)
            image_path.append(save_path)
        return {"image_path": image_path}

    @property
    def examples(self) -> List[Message]:
        """Few-shot conversations demonstrating the expected function calls."""
        return [
            HumanMessage("画一张小狗的图片,图像高度512,图像宽度512"),
            AIMessage(
                "",
                function_call={
                    "name": "ImageGenerationTool",
                    "thoughts": "用户需要我生成一张小狗的图片,图像高度为512,宽度为512。我可以使用ImageGenerationTool工具来满足用户的需求。",
                    "arguments": '{"prompt":"画一张小狗的图片,图像高度512,图像宽度512",'
                    '"width":512,"height":512,"image_num":1}',
                },
            ),
            HumanMessage("生成两张天空的图片"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": "用户想要生成两张天空的图片,我需要调用ImageGenerationTool工具的call接口,"
                    "并设置prompt为'生成两张天空的图片',width和height可以默认为512,image_num默认为2。",
                    "arguments": '{"prompt":"生成两张天空的图片","width":512,"height":512,"image_num":2}',
                },
            ),
            HumanMessage("使用AI作图工具,生成1张小猫的图片,高度和高度是1024"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": "用户需要生成一张小猫的图片,高度和宽度都是1024。我可以使用ImageGenerationTool工具来满足用户的需求。",
                    "arguments": '{"prompt":"生成一张小猫的照片。","width":1024,"height":1024,"image_num":1}',
                },
            ),
        ]
|
erniebot-agent/erniebot_agent/tools/schema.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import annotations
|
16 |
+
|
17 |
+
import inspect
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from typing import Any, Dict, List, Optional, Type, get_args
|
20 |
+
|
21 |
+
from erniebot_agent.utils.logging import logger
|
22 |
+
from pydantic import BaseModel, Field, create_model
|
23 |
+
from pydantic.fields import FieldInfo
|
24 |
+
|
25 |
+
# Sentinel marking field names that could not be mapped to a valid Python
# identifier.  NOTE(review): not referenced within this chunk — confirm its
# usage elsewhere before removing.
INVALID_FIELD_NAME = "__invalid_field_name__"
|
26 |
+
|
27 |
+
|
28 |
+
def is_optional_type(type: Type):
    """Return True when *type* is an Optional/Union annotation admitting None."""
    # get_args() yields () for plain classes, so any() covers that case too.
    return any(arg is None.__class__ for arg in get_args(type))
|
34 |
+
|
35 |
+
|
36 |
+
def get_typing_list_type(type):
    """get typing.List[T] element type

    Args:
        type (typing.List): Generics type
    """
    # Only typing.List aliases carry _name == "List"; anything else yields None.
    if getattr(type, "_name", None) == "List":
        return json_type(get_args(type)[0])
    return None
|
48 |
+
|
49 |
+
|
50 |
+
def json_type(type: Optional[Type[object]] = None):
    """Map a Python/typing annotation to its JSON-schema type name.

    Accepts plain classes (int/str/float/list), typing.List[...] aliases,
    Optional[...] of those, and ToolParameterView subclasses.

    Args:
        type (Optional[Type[object]]): the annotation to map; None maps to
            "object".

    Returns:
        str: a JSON-schema type name, or str(type) for unknown annotations.
    """
    if type is None:
        return "object"

    mapping = {
        int: "integer",
        str: "string",
        list: "array",
        List: "array",
        float: "number",
        ToolParameterView: "object",
    }

    # ToolParameterView subclasses serialize as nested objects.
    if inspect.isclass(type) and issubclass(type, ToolParameterView):
        return "object"

    # typing.List[...] aliases carry _name == "List".
    if getattr(type, "_name", None) == "List":
        return "array"

    if type not in mapping:
        # Unwrap Optional[...]: exactly one non-None argument must remain.
        args = [arg for arg in get_args(type) if arg is not None.__class__]
        if len(args) > 1 or len(args) == 0:
            raise ValueError(
                "only support simple type: FieldType=int/str/float/ToolParameterView, "
                "so the target type should be one of: FieldType, List[FieldType], "
                f"Optional[FieldType], but receive {type}"
            )
        type = args[0]

    if type in mapping:
        return mapping[type]

    # The unwrapped type may itself be a ToolParameterView subclass.
    if inspect.isclass(type) and issubclass(type, ToolParameterView):
        return "object"

    # NOTE(review): unknown annotations fall back to their string repr rather
    # than raising — confirm downstream consumers tolerate this.
    return str(type)
|
86 |
+
|
87 |
+
|
88 |
+
def python_type_from_json_type(json_type_dict: dict) -> Type[object]:
    """Map a JSON-schema type dict back to a Python typing annotation.

    Supports the scalar types integer/string/number/object and arrays of
    those; anything else raises.
    """
    simple_types = {"integer": int, "string": str, "number": float, "object": ToolParameterView}
    if json_type_dict["type"] in simple_types:
        return simple_types[json_type_dict["type"]]

    assert (
        json_type_dict["type"] == "array"
    ), f"only support simple_types<{','.join(simple_types)}> and array type"
    assert "type" in json_type_dict["items"], "<items> field must be defined when 'type'=array"

    # Arrays map element-wise through the same scalar table.
    json_type_value = json_type_dict["items"]["type"]
    array_types = {
        "string": List[str],
        "integer": List[int],
        "number": List[float],
        "object": List[ToolParameterView],
    }
    if json_type_value in array_types:
        return array_types[json_type_value]

    raise ValueError(f"unsupported data type: {json_type_value}")
|
109 |
+
|
110 |
+
|
111 |
+
def scrub_dict(d: dict, remove_empty_dict: bool = False) -> Optional[dict]:
    """Recursively drop None-valued entries from *d*.

    Args:
        d (dict): the value to clean; lists are cleaned element-wise and
            non-container values pass through unchanged.
        remove_empty_dict (bool): when True, a dict that becomes empty after
            cleaning collapses to None so its parent drops it as well.

    Returns:
        Optional[dict]: the slimmed-down value, or None when everything was
        scrubbed away and remove_empty_dict is set.
    """
    # Exact-type check mirrors the original semantics: dict subclasses are
    # not scrubbed, they fall through unchanged.
    if type(d) is dict:
        cleaned = {}
        for key, value in d.items():
            scrubbed = scrub_dict(value, remove_empty_dict)
            if scrubbed is not None:
                cleaned[key] = scrubbed
        if cleaned:
            return cleaned
        # Empty result: collapse to None or keep the empty dict.
        return None if remove_empty_dict else {}
    if isinstance(d, list):
        return [scrub_dict(item, remove_empty_dict) for item in d]  # type: ignore
    return d
|
140 |
+
|
141 |
+
|
142 |
+
class OpenAPIProperty(BaseModel):
    """Pydantic model for a single OpenAPI schema property."""

    type: str  # JSON-schema type name, e.g. "string", "array", "object"
    description: Optional[str] = None
    required: Optional[List[str]] = None  # names of required sub-properties
    items: dict = Field(default_factory=dict)  # element schema when type == "array"
    properties: dict = Field(default_factory=dict)  # nested schema when type == "object"
|
148 |
+
|
149 |
+
|
150 |
+
def get_field_openapi_property(field_info: FieldInfo) -> OpenAPIProperty:
    """convert pydantic FieldInfo instance to OpenAPIProperty value

    Args:
        field_info (FieldInfo): the field instance

    Returns:
        OpenAPIProperty: the converted OpenAPI Property
    """
    # Determine the JSON-schema type of the annotation, unwrapping
    # Optional[...] to its inner type first.
    typing_list_type = get_typing_list_type(field_info.annotation)
    if typing_list_type is not None:
        field_type = "array"
    elif is_optional_type(field_info.annotation):
        field_type = json_type(get_args(field_info.annotation)[0])
    else:
        field_type = json_type(field_info.annotation)

    property = {
        "type": field_type,
        "description": field_info.description,
    }

    if property["type"] == "array":
        if typing_list_type == "object":
            # List[ToolParameterView]: nest the element model's schema.
            list_type: Type[ToolParameterView] = get_args(field_info.annotation)[0]
            property["items"] = list_type.to_openapi_dict()
        else:
            property["items"] = {"type": typing_list_type}
    elif property["type"] == "object":
        if is_optional_type(field_info.annotation):
            field_type_class: Type[ToolParameterView] = get_args(field_info.annotation)[0]
        else:
            field_type_class = field_info.annotation

        # Merge the nested model's schema (type/properties/required) in place.
        openapi_dict = field_type_class.to_openapi_dict()
        property.update(openapi_dict)

    # NOTE(review): this is a no-op when "description" exists with value None
    # (dict.get only falls back for *missing* keys) — the intent was likely
    # `property.get("description") or ""`; confirm before changing, since a
    # None description is later scrubbed while "" would be emitted.
    property["description"] = property.get("description", "")
    return OpenAPIProperty(**property)
|
189 |
+
|
190 |
+
|
191 |
+
class ToolParameterView(BaseModel):
    """Pydantic view of a tool's parameter/response schema.

    Subclasses are usually created dynamically (via ``create_model``) from an
    OpenAPI ``components.schemas`` entry, and can be rendered back into an
    OpenAPI / function-calling schema dict.
    """

    @classmethod
    def from_openapi_dict(cls, name, schema: dict) -> Type[ToolParameterView]:
        """Build a ``ToolParameterView`` subclass from an OpenAPI component schema.

        Args:
            name: name of the component schema. NOTE(review): currently unused —
                the generated model is always named ``OpenAPIParameterView``.
            schema (dict): the OpenAPI component-schema definition whose
                ``properties`` entries become pydantic fields.

        Returns:
            Type[ToolParameterView]: a dynamically created pydantic model class.
        """

        # TODO(wj-Mcat): to load Optional field
        fields = {}
        for field_name, field_dict in schema.get("properties", {}).items():
            field_type = python_type_from_json_type(field_dict)

            # Array-of-object properties are modelled recursively: build a
            # nested view class from the `items` schema and wrap it in List.
            if field_type is List[ToolParameterView]:
                SubParameterView: Type[ToolParameterView] = ToolParameterView.from_openapi_dict(
                    field_name, field_dict["items"]
                )
                field_type = List[SubParameterView]  # type: ignore

            # TODO(wj-Mcat): remove supporting for `summary` field
            if "summary" in field_dict:
                description = field_dict["summary"]
                logger.info("`summary` field will be deprecated, please use `description`")

                # When both are present, `description` wins over the
                # deprecated `summary` field.
                if "description" in field_dict:
                    logger.info("`description` field will be used instead of `summary`")
                    description = field_dict["description"]
            else:
                description = field_dict.get("description", None)

            # Pydantic expects a string description; normalize None to "".
            description = description or ""

            field = FieldInfo(annotation=field_type, description=description)

            # TODO(wj-Mcat): to handle list field required & not-required
            # if get_typing_list_type(field_type) is not None:
            #     field.default_factory = list

            fields[field_name] = (field_type, field)

        return create_model("OpenAPIParameterView", __base__=ToolParameterView, **fields)

    @classmethod
    def to_openapi_dict(cls) -> dict:
        """Convert this ParameterView class to an OpenAPI schema dict.

        Returns:
            dict: ``{"type": "object", "properties": ...}`` plus a
            ``required`` list when any field is mandatory.
        """

        required_names, properties = [], {}
        for field_name, field_info in cls.model_fields.items():
            # Only fields that are both required and non-Optional are
            # advertised in the OpenAPI `required` list.
            if field_info.is_required() and not is_optional_type(field_info.annotation):
                required_names.append(field_name)

            properties[field_name] = dict(get_field_openapi_property(field_info))

        result = {
            "type": "object",
            "properties": properties,
        }
        if len(required_names) > 0:
            result["required"] = required_names
        # Strip empty sub-dicts so the emitted schema stays minimal.
        result = scrub_dict(result, remove_empty_dict=True)  # type: ignore
        return result or {}

    @classmethod
    def function_call_schema(cls) -> dict:
        """Get the function-calling schema of this view.

        Returns:
            dict: the schema of this view for LLM function calling
            (currently identical to the OpenAPI schema).
        """
        return cls.to_openapi_dict()

    @classmethod
    def from_dict(cls, field_map: Dict[str, Any]):
        """
        Class method to create a Pydantic model dynamically based on a dictionary.

        Args:
            field_map (Dict[str, Any]): A dictionary mapping field names to their corresponding type
                and description (keys ``"type"`` and ``"description"``).

        Returns:
            PydanticModel: A dynamically created Pydantic model with fields specified by the
                input dictionary.

        Note:
            This method is used to create a Pydantic model dynamically based on the provided dictionary,
            where each field's type and description are specified in the input.

        """
        fields = {}
        for field_name, field_dict in field_map.items():
            field_type = field_dict["type"]
            description = field_dict["description"]
            field = FieldInfo(annotation=field_type, description=description)
            fields[field_name] = (field_type, field)
        return create_model(cls.__name__, __base__=ToolParameterView, **fields)
|
294 |
+
|
295 |
+
|
296 |
+
@dataclass
class RemoteToolView:
    """View over one operation (HTTP method + path) of a remote tool's OpenAPI spec."""

    # Endpoint path of the operation.
    uri: str
    # HTTP method: one of get/post/put/delete.
    method: str
    # OpenAPI operationId, used as the tool name.
    name: str
    description: str
    parameters: Optional[Type[ToolParameterView]] = None
    parameters_description: Optional[str] = None
    returns: Optional[Type[ToolParameterView]] = None
    returns_description: Optional[str] = None
    # Component-schema names the parameters/returns views were resolved from.
    returns_ref_uri: Optional[str] = None
    parameters_ref_uri: Optional[str] = None

    def to_openapi_dict(self):
        """Render this view as an OpenAPI path-item fragment keyed by HTTP method."""
        operation = {
            "operationId": self.name,
            "description": self.description,
        }
        if self.returns is not None:
            operation["responses"] = {
                "200": {
                    "description": self.returns_description,
                    "content": {
                        "application/json": {
                            "schema": {"$ref": "#/components/schemas/" + (self.returns_ref_uri or "")}
                        }
                    },
                }
            }
        if self.parameters is not None:
            operation["requestBody"] = {
                "required": True,
                "content": {
                    "application/json": {
                        "schema": {"$ref": "#/components/schemas/" + (self.parameters_ref_uri or "")}
                    }
                },
            }
        return {self.method: operation}

    @staticmethod
    def from_openapi_dict(
        uri: str, method: str, path_info: dict, parameters_views: dict[str, Type[ToolParameterView]]
    ) -> RemoteToolView:
        """construct RemoteToolView from openapi spec-dict info

        Args:
            uri (str): the url path of remote tool
            method (str): http method: one of [get, post, put, delete]
            path_info (dict): the spec info of remote tool
            parameters_views (dict[str, ParametersView]):
                the dict of parameters views which are the schema of input/output of tool

        Returns:
            RemoteToolView: the instance of remote tool view
        """
        parameters = None
        parameters_description = None
        parameters_ref_uri = None
        if "requestBody" in path_info:
            # Resolve the request-body $ref to a previously parsed schema view.
            body_ref = path_info["requestBody"]["content"]["application/json"]["schema"]["$ref"]
            parameters_ref_uri = body_ref.split("/")[-1]
            assert parameters_ref_uri in parameters_views
            parameters = parameters_views[parameters_ref_uri]
            parameters_description = path_info["requestBody"].get("description", None)

        returns = None
        returns_description = None
        returns_ref_uri = None
        if "responses" in path_info:
            # Only the first declared response (typically "200") is used.
            first_response = list(path_info["responses"].values())[0]
            response_ref = first_response["content"]["application/json"]["schema"]["$ref"]
            returns_ref_uri = response_ref.split("/")[-1]
            assert returns_ref_uri in parameters_views
            returns = parameters_views[returns_ref_uri]
            returns_description = first_response.get("description", None)

        return RemoteToolView(
            name=path_info["operationId"],
            parameters=parameters,
            parameters_description=parameters_description,
            returns=returns,
            returns_description=returns_description,
            description=path_info.get("description", path_info.get("summary", None)),
            method=method,
            uri=uri,
            # save ref id info
            returns_ref_uri=returns_ref_uri,
            parameters_ref_uri=parameters_ref_uri,
        )

    def function_call_schema(self):
        """Build the function-calling schema for this operation."""
        schema = {
            "name": self.name,
            "description": self.description,
            # TODO(wj-Mcat): read examples from openapi.yaml
            # "examples": [example.to_dict() for example in self.examples],
        }
        if self.parameters is None:
            # No request body: advertise an empty object schema.
            schema["parameters"] = {"type": "object", "properties": {}}
        else:
            schema["parameters"] = self.parameters.function_call_schema()  # type: ignore

        if self.returns is not None:
            schema["responses"] = self.returns.function_call_schema()  # type: ignore
        return scrub_dict(schema) or {}
|
404 |
+
|
405 |
+
|
406 |
+
@dataclass
class Endpoint:
    """A single server endpoint of a remote tool service (OpenAPI ``servers`` entry)."""

    # Base URL at which the remote tool server is reachable.
    url: str
|
409 |
+
|
410 |
+
|
411 |
+
@dataclass
class EndpointInfo:
    """Metadata of a remote tool service (OpenAPI ``info`` section)."""

    # Human-readable title of the service.
    title: str
    # Short description of what the service does.
    description: str
    # Version string of the service spec.
    version: str
|
erniebot-agent/erniebot_agent/tools/tool_manager.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
from typing import Dict, List, final
|
17 |
+
|
18 |
+
from erniebot_agent.tools.base import Tool
|
19 |
+
|
20 |
+
|
21 |
+
@final
class ToolManager(object):
    """Registry of tools available to an agent, keyed by tool name.

    This implementation is based on `ToolsManager` in
    https://github.com/deepset-ai/haystack/blob/main/haystack/agents/base.py
    """

    def __init__(self, tools: List[Tool]) -> None:
        """Create a manager pre-populated with *tools*; duplicate names raise."""
        super().__init__()
        # Insertion-ordered mapping from tool name to tool instance.
        self._tools: Dict[str, Tool] = {}
        for tool in tools:
            self.add_tool(tool)

    def __getitem__(self, tool_name: str) -> Tool:
        """Dict-style access; alias for :meth:`get_tool`."""
        return self.get_tool(tool_name)

    def add_tool(self, tool: Tool) -> None:
        """Register *tool* under its ``tool_name``; reject duplicate names."""
        name = tool.tool_name
        if name in self._tools:
            raise RuntimeError(f"Name {repr(name)} is already registered.")
        self._tools[name] = tool

    def remove_tool(self, tool: Tool) -> None:
        """Unregister *tool*; it must be the exact registered instance."""
        name = tool.tool_name
        if name not in self._tools:
            raise RuntimeError(f"Name {repr(name)} is not registered.")
        # Guard against removing a different tool that shares the name.
        if self._tools[name] is not tool:
            raise RuntimeError(f"The tool with the registered name {repr(name)} is not the given tool.")
        del self._tools[name]

    def get_tool(self, tool_name: str) -> Tool:
        """Return the tool registered as *tool_name*, or raise RuntimeError."""
        try:
            return self._tools[tool_name]
        except KeyError:
            raise RuntimeError(f"Name {repr(tool_name)} is not registered.") from None

    def get_tools(self) -> List[Tool]:
        """Return all registered tools in registration order."""
        return list(self._tools.values())

    def get_tool_names(self) -> str:
        """Return the registered tool names as a comma-separated string."""
        return ", ".join(self._tools.keys())

    def get_tool_names_with_descriptions(self) -> str:
        """Return one ``name:schema-json`` line per registered tool."""
        lines = [
            f"{name}:{json.dumps(tool.function_call_schema())}" for name, tool in self._tools.items()
        ]
        return "\n".join(lines)

    def get_tool_schemas(self):
        """Return the function-calling schema of every registered tool."""
        return [registered.function_call_schema() for registered in self._tools.values()]
|
erniebot-agent/erniebot_agent/utils/__init__.py
ADDED
File without changes
|