Spaces:
Runtime error
Runtime error
fix bug
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- __pycache__/cmd_perform.cpython-38.pyc +0 -0
- __pycache__/create_sop.cpython-38.pyc +0 -0
- __pycache__/gradio_base.cpython-38.pyc +0 -0
- __pycache__/gradio_config.cpython-38.pyc +0 -0
- agents/Action/__init__.py +1 -0
- agents/Action/__pycache__/__init__.cpython-38.pyc +0 -0
- agents/Action/__pycache__/base_action.cpython-38.pyc +0 -0
- agents/Action/base_action.py +48 -0
- agents/Agent/Agent.py +243 -0
- agents/Agent/__init__.py +1 -0
- agents/Agent/__pycache__/Agent.cpython-38.pyc +0 -0
- agents/Agent/__pycache__/__init__.cpython-38.pyc +0 -0
- agents/Component/ExtraComponent.py +128 -0
- agents/Component/PromptComponent.py +133 -0
- agents/Component/ToolComponent.py +887 -0
- agents/Component/__init__.py +3 -0
- agents/Component/__pycache__/ExtraComponent.cpython-38.pyc +0 -0
- agents/Component/__pycache__/PromptComponent.cpython-38.pyc +0 -0
- agents/Component/__pycache__/ToolComponent.cpython-38.pyc +0 -0
- agents/Component/__pycache__/__init__.cpython-38.pyc +0 -0
- agents/Environment/__init__.py +1 -0
- agents/Environment/__pycache__/__init__.cpython-38.pyc +0 -0
- agents/Environment/__pycache__/base_environment.cpython-38.pyc +0 -0
- agents/Environment/base_environment.py +167 -0
- agents/LLM/__init__.py +0 -0
- agents/LLM/__pycache__/__init__.cpython-38.pyc +0 -0
- agents/LLM/__pycache__/base_LLM.cpython-38.pyc +0 -0
- agents/LLM/base_LLM.py +133 -0
- agents/Memory/__init__.py +1 -0
- agents/Memory/__pycache__/__init__.cpython-38.pyc +0 -0
- agents/Memory/__pycache__/base_Memory.cpython-38.pyc +0 -0
- agents/Memory/base_Memory.py +32 -0
- agents/Prompt/__init__.py +1 -0
- agents/Prompt/__pycache__/__init__.cpython-38.pyc +0 -0
- agents/Prompt/__pycache__/base_Prompts.cpython-38.pyc +0 -0
- agents/Prompt/base_Prompts.py +83 -0
- agents/SOP.py +296 -0
- agents/State.py +142 -0
- agents/__init__.py +4 -0
- agents/__pycache__/SOP.cpython-38.pyc +0 -0
- agents/__pycache__/State.cpython-38.pyc +0 -0
- agents/__pycache__/__init__.cpython-38.pyc +0 -0
- agents/__pycache__/evolve.cpython-38.pyc +0 -0
- agents/__pycache__/utils.cpython-38.pyc +0 -0
- agents/evolve.py +17 -0
- agents/template.py +111 -0
- agents/utils.py +480 -0
- app.py +9 -6
- create_sop.py +1 -2
- gradio_backend.py +10 -8
__pycache__/cmd_perform.cpython-38.pyc
ADDED
Binary file (1.78 kB). View file
|
|
__pycache__/create_sop.cpython-38.pyc
ADDED
Binary file (8.27 kB). View file
|
|
__pycache__/gradio_base.cpython-38.pyc
ADDED
Binary file (16.4 kB). View file
|
|
__pycache__/gradio_config.cpython-38.pyc
ADDED
Binary file (12.4 kB). View file
|
|
agents/Action/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .base_action import Action
|
agents/Action/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (158 Bytes). View file
|
|
agents/Action/__pycache__/base_action.cpython-38.pyc
ADDED
Binary file (1.32 kB). View file
|
|
agents/Action/base_action.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from Memory import Memory
|
2 |
+
class Action:
    """
    The basic action unit of an agent.

    Holds the raw response produced during one step together with the name
    and role of the agent that produced it. Every keyword argument passed to
    the constructor is stored as an instance attribute and overrides the
    defaults set in ``__init__``.
    """

    def __init__(self, **kwargs):
        # Defaults; any of them can be overridden through kwargs.
        self.response = None
        self.is_user = False
        self.res_dict = {}
        self.name = ""
        self.role = ""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def process(self):
        """
        Consume ``self.response`` and turn it into a memory entry.

        The response is iterated chunk by chunk (it may be a plain string or a
        generator of string chunks) and concatenated. Everything up to and
        including the last "<name>:" prefix is dropped, the text is printed for
        non-user agents, and when it contains a "<title>" tag the generated
        code is written below ``output_code/``.

        Return : memory(Memory)
        """
        send_name = self.name
        send_role = self.role

        # A missing response is treated as empty text instead of raising
        # TypeError when iterated.
        response = self.response if self.response is not None else ""
        text = "".join(response)

        # The third person in the dialogue is deleted: keep only what follows
        # the speaker prefix, repeating until no prefix is left.
        prefix = f"{send_name}:"
        while prefix in text:
            text = text[text.index(prefix) + len(prefix):]

        if not self.is_user:
            print(f"{send_name}({send_role}):{text}")

        # for software
        if "<title>" in text:
            # These names were never imported by this module, so this branch
            # used to fail with NameError. They are imported lazily here so the
            # common path does not depend on them.
            import os
            from utils import extract

            title = extract(text, "title") or ""
            python = extract(text, "python") or ""
            # The title comes from model output: keep only its final path
            # component so the file cannot be written outside output_code/.
            safe_title = os.path.basename(title.strip())
            if safe_title:
                os.makedirs("output_code", exist_ok=True)
                file_name = os.path.join("output_code", safe_title)
                with open(file_name, "w", encoding="utf-8") as f:
                    f.write(python)

        memory = Memory(send_role, send_name, text)
        return memory
|
47 |
+
|
48 |
+
|
agents/Agent/Agent.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The AIWaves Inc. team.
|
3 |
+
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""LLM autonoumous agent"""
|
17 |
+
from LLM.base_LLM import *
|
18 |
+
from Component import *
|
19 |
+
from Action import Action
|
20 |
+
from Prompt import *
|
21 |
+
|
22 |
+
headers = {
|
23 |
+
"Content-Type": "text/event-stream",
|
24 |
+
"Cache-Control": "no-cache",
|
25 |
+
"X-Accel-Buffering": "no",
|
26 |
+
}
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
class Agent:
    """
    Auto agent, input the JSON of SOP.

    An agent plays one role per state, owns one LLM per state and keeps a
    long-term memory (list of messages) and a short-term memory (summary).
    """

    # Agent should have args: agents,states
    def __init__(self, name, agent_state_roles, **kwargs) -> None:
        """
        Args:
            name: name of the agent.
            agent_state_roles: dict mapping state name -> role played in it.
            **kwargs: "style", "LLMs" and "is_user" are required;
                "begins" is optional and defaults to False.
        """
        self.state_roles = agent_state_roles
        self.name = name

        self.style = kwargs["style"]
        self.LLMs = kwargs["LLMs"]
        self.LLM = None
        self.is_user = kwargs["is_user"]
        self.begins = kwargs.get("begins", False)
        self.current_role = ""
        self.long_term_memory = []
        self.short_term_memory = ""
        self.current_state = None
        self.first_speak = True
        self.environment = None

    @classmethod
    def from_config(cls, config_path):
        """
        Initialize agents based on json file

        Return (in this order):
        agents(dict) : key:agent_name;value:class(Agent)
        roles_to_names(dict) : key:state_name  value:(dict; (key:agent_role ; value:agent_name))
        names_to_roles(dict) : key:state_name  value:(dict; (key:agent_name ; value:agent_role))
        """
        # JSON is read as UTF-8 explicitly so the result does not depend on the
        # platform's default locale encoding.
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        roles_to_names = {}
        names_to_roles = {}
        agents = {}
        # Names listed in the User_Names environment variable (a JSON list)
        # are driven by a human instead of an LLM.
        user_names = json.loads(os.environ["User_Names"]) if "User_Names" in os.environ else []
        for agent_name, agent_dict in config["agents"].items():
            agent_state_roles = {}
            agent_LLMs = {}
            agent_begins = {}
            for state_name, agent_role in agent_dict["roles"].items():

                agent_begins[state_name] = {}

                if state_name not in roles_to_names:
                    roles_to_names[state_name] = {}
                if state_name not in names_to_roles:
                    names_to_roles[state_name] = {}
                roles_to_names[state_name][agent_role] = agent_name
                names_to_roles[state_name][agent_name] = agent_role
                agent_state_roles[state_name] = agent_role
                current_state = config["states"][state_name]

                current_state_begin_role = current_state["begin_role"] if "begin_role" in current_state else current_state["roles"][0]
                # The agent only opens the state when the state names a begin_role explicitly.
                agent_begins[state_name]["is_begin"] = current_state_begin_role == agent_role if "begin_role" in current_state else False
                agent_begins[state_name]["begin_query"] = current_state["begin_query"] if "begin_query" in current_state else " "
                agent_LLMs[state_name] = init_LLM(f"logs/{agent_name}", **current_state["agent_states"][agent_role])
            agents[agent_name] = cls(
                agent_name,
                agent_state_roles,
                LLMs=agent_LLMs,
                is_user=agent_name in user_names,
                style=agent_dict["style"],
                begins=agent_begins,
            )
        # Only checked when exactly two agents are configured; note that the
        # check is skipped entirely when Python runs with -O.
        assert len(config["agents"].keys()) != 2 or (
            roles_to_names[config["root"]][config["states"][config["root"]]["begin_role"]] not in user_names
            and "begin_query" in config["states"][config["root"]]
        ), "In a single-agent scenario, there must be an opening statement and it must be the agent"
        return agents, roles_to_names, names_to_roles

    def step(self, current_state, input=""):
        """
        return actions by current state and environment

        Args:
            current_state: state object; its chat_nums is incremented and its
                is_begin flag is cleared here.
            input: text typed by the user; only used when this agent is a user.

        Return: action(Action)
        """

        current_state.chat_nums += 1
        # Remember the "begin" flags before clearing them: they apply to the
        # first step in this state only.
        state_begin = current_state.is_begin
        agent_begin = self.begins[current_state.name]["is_begin"]
        self.begins[current_state.name]["is_begin"] = False
        current_state.is_begin = False
        environment = self.environment

        self.current_state = current_state
        # First update the information according to the current environment

        response = " "
        res_dict = {}

        if self.is_user:
            response = f"{self.name}:{input}"
        else:
            if len(environment.shared_memory["long_term_memory"]) > 0:
                current_history = self.observe()
                self.long_term_memory.append(current_history)
            if agent_begin:
                # Emit the opening query one character at a time so it can be
                # consumed like a streamed LLM response.
                response = (char for char in self.begins[current_state.name]["begin_query"])
            else:
                response, res_dict = self.act()

        action_dict = {
            "response": response,
            "res_dict": res_dict,
            "role": self.state_roles[current_state.name],
            "name": self.name,
            "state_begin": state_begin,
            "agent_begin": agent_begin,
            "is_user": self.is_user,
        }
        return Action(**action_dict)

    def act(self):
        """
        return actions by the current state

        Return: (response, res_dict) where response is the streamed reply of
        the LLM bound to the current state and res_dict comes from compile().
        """
        current_state = self.current_state
        chat_history = self.long_term_memory
        current_LLM = self.LLMs[current_state.name]

        system_prompt, last_prompt, res_dict = self.compile()

        response = current_LLM.get_response(
            chat_history, system_prompt, last_prompt, stream=True
        )
        return response, res_dict

    def update_memory(self, memory):
        """
        Append ``memory.content`` to the long-term memory and, once enough new
        messages have accumulated in the environment, refresh the short-term
        memory with an LLM summary.
        """
        self.long_term_memory.append(
            {"role": "assistant", "content": memory.content}
        )

        # NOTE(review): eval() on an environment variable executes arbitrary
        # code; int(os.environ[...]) would be safer if the value is always an
        # integer literal - confirm before changing.
        MAX_CHAT_HISTORY = eval(os.environ["MAX_CHAT_HISTORY"])
        environment = self.environment
        current_chat_history_idx = environment.current_chat_history_idx if environment.environment_type == "competive" else 0

        current_long_term_memory = environment.shared_memory["long_term_memory"][current_chat_history_idx:]
        last_conversation_idx = environment._get_agent_last_conversation_idx(self, current_long_term_memory)
        if len(current_long_term_memory) - last_conversation_idx >= MAX_CHAT_HISTORY:
            current_state = self.current_state
            current_role = self.state_roles[current_state.name]
            current_component_dict = current_state.components[current_role]

            # get chat history from new conversation
            conversations = environment._get_agent_new_memory(self, current_long_term_memory)

            # get summary
            summary_prompt = (
                current_state.summary_prompt[current_role]
                if current_state.summary_prompt
                else f"""your name is {self.name},your role is{current_component_dict["style"].role},your task is {current_component_dict["task"].task}.\n"""
            )
            # NOTE(review): the template is defined outside this class and is
            # evaluated in this scope, so the local names above (for example
            # summary_prompt and conversations) must keep their names -
            # confirm against the template before renaming anything.
            summary_prompt = eval(Agent_summary_system_prompt)
            summary = self.LLMs[current_state.name].get_response(None, summary_prompt, stream=False)
            self.short_term_memory = summary

    def compile(self):
        """
        get prompt from state depend on your role
        Return:
        system_prompt:system_prompt for agents's LLM
        last_prompt:last_prompt for agents's LLM
        res_dict(dict): Other return from tool component.For example: search engine results
        """
        current_state = self.current_state
        # __init__ declares current_role; current_roles is still set as well so
        # any code outside this class reading the old attribute keeps working.
        self.current_role = self.current_roles = self.state_roles[current_state.name]
        current_state_name = current_state.name
        self.LLM = self.LLMs[current_state_name]
        components = current_state.components[self.state_roles[current_state_name]]

        system_prompt = self.current_state.environment_prompt
        last_prompt = ""

        res_dict = {}
        for component in components.values():
            # Output/Last components are checked first because they are also
            # PromptComponent subclasses but belong in the last prompt.
            if isinstance(component, (OutputComponent, LastComponent)):
                last_prompt = last_prompt + "\n" + component.get_prompt(self)
            elif isinstance(component, PromptComponent):
                system_prompt = (
                    system_prompt + "\n" + component.get_prompt(self)
                )
            elif isinstance(component, ToolComponent):
                response = component.func(self)
                if "prompt" in response and response["prompt"]:
                    last_prompt = last_prompt + "\n" + response["prompt"]
                res_dict.update(response)

        # NOTE(review): the templates below are defined outside this class and
        # are evaluated in this scope; name, query, last_prompt and
        # system_prompt are presumably referenced by them, so keep the names.
        name = self.name
        query = self.environment.shared_memory["long_term_memory"][-1]
        last_prompt = eval(Agent_last_prompt)
        system_prompt = eval(Agent_system_prompt)
        return system_prompt, last_prompt, res_dict

    def observe(self):
        """
        Update one's own memory according to the current environment, including: updating short-term memory; updating long-term memory

        Delegates to ``self.environment._observe(self)`` and returns its result.
        """
        return self.environment._observe(self)

    def generate_sop(self):
        """Placeholder; not implemented."""
        pass

    def reflection(self):
        """Placeholder; not implemented."""
        pass
|
242 |
+
|
243 |
+
|
agents/Agent/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .Agent import Agent
|
agents/Agent/__pycache__/Agent.cpython-38.pyc
ADDED
Binary file (6.2 kB). View file
|
|
agents/Agent/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (141 Bytes). View file
|
|
agents/Component/ExtraComponent.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .ToolComponent import ToolComponent
|
2 |
+
import json
|
3 |
+
from utils import flatten_dict,get_embedding,matching_category,search_with_api,limit_keys,limit_values
|
4 |
+
import os
|
5 |
+
|
6 |
+
|
7 |
+
class CategoryRequirementsComponent(ToolComponent):
    """
    Tool component for a shopping assistant.

    It asks the LLM which product category and requirements the user has,
    matches the category against the known leaf categories and builds a prompt
    (and, on a good match, a recommendation list) from the search results.
    """

    def __init__(self, information_path):
        """
        Args:
            information_path: iterable of paths to UTF-8 JSON files. Each file
                holds a list of dicts with a "cat_leaf_name" and an
                "information" entry.
        """
        super().__init__()
        self.information_dataset = []
        self.leaf_name = []
        for toy_path in information_path:
            with open(toy_path, encoding="utf-8") as json_file:
                data = json.load(json_file)
            for d in data:
                # A name such as "a/b" is indexed under "a", "b" and "a/b".
                if "/" in d["cat_leaf_name"]:
                    leaf_names = d["cat_leaf_name"].split("/") + [d["cat_leaf_name"]]
                else:
                    leaf_names = [d["cat_leaf_name"]]
                for name in leaf_names:
                    self.leaf_name.append(name)
                    new_d = d.copy()
                    new_d["cat_leaf_name"] = name
                    new_d["information"] = flatten_dict(new_d["information"])
                    self.information_dataset.append(new_d)

        self.target_embbeding = get_embedding(
            self.leaf_name
        )

    def search_information(self, category, information_dataset):
        """
        Return the "information" dict of the first entry whose cat_leaf_name
        equals ``category``, without empty values and without the "相关分类"
        key. Returns an empty dict when nothing matches.
        """
        knowledge = {}
        for d in information_dataset:
            if category == d["cat_leaf_name"]:
                knowledge = d["information"]
                knowledge = {
                    key: value
                    for key, value in knowledge.items()
                    if (value and key != "相关分类")
                }
                break
        return knowledge

    def func(self, agent):
        """
        Build the shopping prompt for ``agent``.

        Return: dict with "prompt" and, when the category matched well enough,
        "recommend"; an empty dict when the LLM extracted neither a category
        nor requirements.
        """
        prompt = ""
        messages = agent.long_term_memory
        outputdict = {}
        functions = [
            {
                "name": "search_information",
                "description": "根据用户所需要购买商品的种类跟用户的需求去寻找用户所需要的商品",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "category": {
                            "type": "string",
                            "description": "用户现在所需要的商品类别,比如纸尿布,笔记本电脑等,注意,只能有一个",
                        },
                        "requirements": {
                            "type": "string",
                            "description": "用户现在的需求,比如说便宜,安踏品牌等等,可以有多个需求,中间以“ ”分隔",
                        },
                    },
                    "required": ["category", "requirements"],
                },
            }
        ]

        response = agent.LLM.get_response(
            messages,
            None,
            None,
            functions=functions,
            stream=False,
            function_call={"name": "search_information"},
        )
        response_message = json.loads(response["function_call"]["arguments"])
        # .get() so that arguments missing from the model's reply are treated
        # like empty ones instead of raising KeyError.
        category = response_message.get("category") or None
        # Fall back to the category when no explicit requirements were given.
        requirements = response_message.get("requirements") or category
        if not (category or requirements):
            return {}

        topk_result = matching_category(
            category, self.leaf_name, None, self.target_embbeding, top_k=3
        )

        top1_score = topk_result[1][0]
        request_items, top_category = search_with_api(requirements, category)

        # NOTE(review): eval() on an environment variable executes arbitrary
        # code; float(...) would be safer if the value is always a number
        # literal - confirm before changing.
        MIN_CATEGORY_SIM = eval(os.environ["MIN_CATEGORY_SIM"]
                                ) if "MIN_CATEGORY_SIM" in os.environ else 0.7

        if top1_score > MIN_CATEGORY_SIM:
            # Good match: remember the matched category and describe it.
            agent.environment.shared_memory["category"] = topk_result[0][0]
            category = topk_result[0][0]
            information = self.search_information(
                topk_result[0][0], self.information_dataset
            )
            information = limit_keys(information, 3)
            information = limit_values(information, 2)
            prompt += f"""你需要知道的是:用户目前选择的商品是{category},该商品信息为{information}。你需要根据这些商品信息来详细介绍商品,比如详细介绍商品有哪些品牌,有哪些分类等等,并且询问用户是否有更多的需求。"""
            if category in top_category:
                top_category.remove(category)

            recommend = "\n经过搜索后,推荐商品如下:\n"
            prompt += "筛选出的商品如下:\n"

            for i, request_item in enumerate(request_items):

                itemTitle = request_item["itemTitle"]
                itemPrice = request_item["itemPrice"]
                itemPicUrl = request_item["itemPicUrl"]
                # The price is divided by 100 for display.
                recommend += f"[{i}.商品名称:{itemTitle},商品价格:{float(itemPrice)/100}]({itemPicUrl})\n"
                prompt += f"[{i}.商品名称:{itemTitle},商品价格:{float(itemPrice)/100}]\n"
            outputdict["recommend"] = recommend
            print(recommend)
        else:
            # No close category: steer the model towards similar ones.
            prompt += f"""你需要知道的是:用户目前选择的商品是{category},而我们店里没有这类商品,但是我们店里有一些近似商品,如{top_category},{topk_result[0][0]},你需要对这些近似商品进行介绍,并引导用户购买"""
        outputdict["prompt"] = prompt
        return outputdict
|
128 |
+
|
agents/Component/PromptComponent.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
|
3 |
+
|
4 |
+
class PromptComponent:
    """
    Base class of every component that contributes text to an agent's prompt.

    The class does not use ABCMeta, so ``abstractmethod`` is only a marker here
    and instances of the base class can still be created.
    """

    def __init__(self):
        """Nothing to set up in the base class."""

    @abstractmethod
    def get_prompt(self, agent):
        """Return the prompt text for ``agent``; the base version returns None."""
        return None
|
11 |
+
|
12 |
+
class TaskComponent(PromptComponent):
    """Prompt component stating the task the agent has to execute."""

    def __init__(self, task):
        super().__init__()
        self.task = task

    def get_prompt(self, agent):
        """Return the task wrapped in <task> tags; ``agent`` is not used."""
        task_text = self.task
        return f"""The task you need to execute is: <task>{task_text}</task>.\n"""
|
19 |
+
|
20 |
+
|
21 |
+
class OutputComponent(PromptComponent):
    """Prompt component asking the model to answer inside a given tag."""

    def __init__(self, output):
        super().__init__()
        self.output = output

    def get_prompt(self, agent):
        """Return the output-format instruction for the configured tag; ``agent`` is not used."""
        tag = self.output
        # The trailing backslash joins both source lines into one prompt line.
        return f"""Please contact the above to extract <{tag}> and </{tag}>, \
do not perform additional output, please output in strict accordance with the above format!\n"""
|
29 |
+
|
30 |
+
|
31 |
+
class SystemComponent(PromptComponent):
    """Prompt component that contributes a fixed piece of system prompt."""

    def __init__(self, system_prompt):
        super().__init__()
        self.system_prompt = system_prompt

    def get_prompt(self, agent):
        """Return the stored system prompt unchanged; ``agent`` is not used."""
        fixed_text = self.system_prompt
        return fixed_text
|
38 |
+
|
39 |
+
class LastComponent(PromptComponent):
    """Prompt component that contributes a fixed piece of last prompt."""

    def __init__(self, last_prompt):
        super().__init__()
        self.last_prompt = last_prompt

    def get_prompt(self, agent):
        """Return the stored last prompt unchanged; ``agent`` is not used."""
        fixed_text = self.last_prompt
        return fixed_text
|
46 |
+
|
47 |
+
|
48 |
+
class StyleComponent(PromptComponent):
    """
    Role and style component.
    """

    def __init__(self, role):
        super().__init__()
        self.role = role

    def get_prompt(self, agent):
        """Return the role / name / style instruction built from this role and the agent's name and style."""
        agent_name, output_style = agent.name, agent.style
        role_text = self.role
        # The trailing backslash joins both source lines into one prompt line.
        return f"""Now your role is:\n<role>{role_text}</role>, your name is:\n<name>{agent_name}</name>. \
You need to follow the output style:\n<style>{output_style}</style>.\n"""
|
62 |
+
|
63 |
+
|
64 |
+
class RuleComponent(PromptComponent):
    """Prompt component stating a rule the agent has to follow."""

    def __init__(self, rule):
        super().__init__()
        self.rule = rule

    def get_prompt(self, agent):
        """Return the rule wrapped in <rule> tags; ``agent`` is not used."""
        rule_text = self.rule
        return f"""The rule you need to follow is:\n<rule>{rule_text}</rule>.\n"""
|
71 |
+
|
72 |
+
|
73 |
+
class DemonstrationComponent(PromptComponent):
    """
    Prompt component listing example answers.

    ``demonstrations`` is a list of strings; more can be added later with
    ``add_demonstration``.
    """

    def __init__(self, demonstrations):
        super().__init__()
        self.demonstrations = demonstrations

    def add_demonstration(self, demonstration):
        """Append one more example to the list."""
        self.demonstrations.append(demonstration)

    def get_prompt(self, agent):
        """Return all examples, one per line, inside <demonstrations> tags; ``agent`` is not used."""
        header = "Here are demonstrations you can refer to:\n<demonstrations>"
        body = "".join("\n" + example for example in self.demonstrations)
        return header + body + "</demonstrations>\n"
|
91 |
+
|
92 |
+
|
93 |
+
class CoTComponent(PromptComponent):
    """
    Prompt component listing worked "thinking" examples.

    ``demonstrations`` is a list of strings; more can be added later with
    ``add_demonstration``.
    """

    def __init__(self, demonstrations):
        super().__init__()
        self.demonstrations = demonstrations

    def add_demonstration(self, demonstration):
        """Append one more example to the list."""
        self.demonstrations.append(demonstration)

    def get_prompt(self, agent):
        """Return all examples, one per line, inside <demonstrations> tags; ``agent`` is not used."""
        header = "You need to think in detail before outputting, the thinking case is as follows:\n<demonstrations>"
        body = "".join("\n" + example for example in self.demonstrations)
        return header + body + "</demonstrations>\n"
|
111 |
+
|
112 |
+
|
113 |
+
class CustomizeComponent(PromptComponent):
    """
    Custom template
    template(str) : example: "i am {name}"
    keywords(list) : example : ["name"]
    example : agent.environment.shared_memory["name"] = "Lilong"
    the component looks every keyword up in the environment's shared memory
    and passes the values to the template as named format fields.
    Return : "i am Lilong"
    """

    def __init__(self, template, keywords) -> None:
        super().__init__()
        self.template = template
        self.keywords = keywords

    def get_prompt(self, agent):
        """Fill the template with the shared-memory value of each keyword."""
        shared_memory = agent.environment.shared_memory
        values = {keyword: shared_memory[keyword] for keyword in self.keywords}
        return self.template.format(**values)
|
agents/Component/ToolComponent.py
ADDED
@@ -0,0 +1,887 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
import uuid
|
3 |
+
from text2vec import semantic_search
|
4 |
+
from utils import (
|
5 |
+
get_relevant_history,
|
6 |
+
load_knowledge_base_qa,
|
7 |
+
load_knowledge_base_UnstructuredFile,
|
8 |
+
get_embedding,
|
9 |
+
extract,
|
10 |
+
)
|
11 |
+
import json
|
12 |
+
from typing import Dict, List
|
13 |
+
import os
|
14 |
+
from googleapiclient.discovery import build
|
15 |
+
import requests
|
16 |
+
from selenium import webdriver
|
17 |
+
from selenium.webdriver.common.by import By
|
18 |
+
from selenium.webdriver.support.ui import WebDriverWait
|
19 |
+
from selenium.webdriver.support import expected_conditions as EC
|
20 |
+
from bs4 import BeautifulSoup
|
21 |
+
import base64
|
22 |
+
import re
|
23 |
+
from datetime import datetime, timedelta
|
24 |
+
from typing import Tuple, List, Any, Dict
|
25 |
+
from email.mime.text import MIMEText
|
26 |
+
from email.mime.multipart import MIMEMultipart
|
27 |
+
from google.auth.transport.requests import Request
|
28 |
+
from google.oauth2.credentials import Credentials
|
29 |
+
from google_auth_oauthlib.flow import InstalledAppFlow
|
30 |
+
from googleapiclient.discovery import build
|
31 |
+
from googleapiclient.errors import HttpError
|
32 |
+
from tqdm import tqdm
|
33 |
+
|
34 |
+
class ToolComponent:
    """
    Base class of every tool component.

    The class does not use ABCMeta, so ``abstractmethod`` is only a marker here
    and instances of the base class can still be created.
    """

    def __init__(self):
        pass

    @abstractmethod
    def func(self, agent=None):
        """
        Run the tool for ``agent`` and return a dict of results.

        ``agent`` is optional so the base signature accepts the same call as
        the subclasses, which all take the agent as their only argument. The
        base version does nothing and returns None.
        """
        pass
|
41 |
+
|
42 |
+
class KnowledgeBaseComponent(ToolComponent):
    """
    Inject knowledge base
    top_k : Top_k with the highest matching degree
    type : "QA" or others
    knowledge_base(json_path) : knowledge_base_path
    """

    def __init__(self, top_k, type, knowledge_base):
        super().__init__()
        self.top_k = top_k
        self.type = type
        self.knowledge_base = knowledge_base

        if self.type == "QA":
            (
                self.kb_embeddings,
                self.kb_questions,
                self.kb_answers,
                self.kb_chunks,
            ) = load_knowledge_base_qa(self.knowledge_base)
        else:
            # Unstructured knowledge bases only provide embeddings and chunks;
            # kb_questions / kb_answers do not exist in this mode.
            self.kb_embeddings, self.kb_chunks = load_knowledge_base_UnstructuredFile(
                self.knowledge_base
            )

    def func(self, agent):
        """
        Look up the knowledge base for the agent's latest message.

        The query is taken from the "query" tag of the last long-term-memory
        entry. In "QA" mode only the best matching question/answer pair is
        used; otherwise up to ``top_k`` distinct chunks are collected.

        Return: dict with a single "prompt" entry - the relevant content, or
        "No matching knowledge base" when the best score is below 0.5 or there
        are no hits at all.
        """
        query = agent.long_term_memory[-1]["content"] if agent.long_term_memory else ""
        knowledge = ""
        query = extract(query, "query")
        query_embedding = get_embedding(query)
        hits = semantic_search(query_embedding, self.kb_embeddings, top_k=50)
        hits = hits[0]
        if not hits:
            # Nothing to score; previously this raised IndexError on hits[0].
            return {"prompt": "No matching knowledge base"}

        is_qa = self.type == "QA"
        # QA mode keeps a single pair; otherwise up to top_k chunks are kept.
        limit = 1 if is_qa else self.top_k
        seen = []
        for hit in hits:
            matching_idx = hit["corpus_id"]
            chunk = self.kb_chunks[matching_idx]
            if chunk in seen:
                continue
            if is_qa:
                knowledge = (
                    knowledge
                    + f"question:{self.kb_questions[matching_idx]},answer:{self.kb_answers[matching_idx]}\n\n"
                )
            else:
                # Use the chunk itself: kb_answers is only defined in QA mode,
                # so reading it here raised AttributeError.
                knowledge = knowledge + f"{chunk}\n\n"
            seen.append(chunk)
            if len(seen) == limit:
                break

        print(hits[0]["score"])
        score = hits[0]["score"]
        if score < 0.5:
            return {"prompt": "No matching knowledge base"}
        if not is_qa:
            print(knowledge)
        return {"prompt": "The relevant content is: " + knowledge + "\n"}
|
115 |
+
|
116 |
+
|
117 |
+
class StaticComponent(ToolComponent):
    """Tool that always hands back the same pre-configured output."""

    def __init__(self, output):
        """Remember ``output``, the value returned on every call."""
        super().__init__()
        self.output = output

    def func(self, agent):
        """Return the stored output under the "response" key; ``agent`` is unused."""
        return {"response": self.output}
|
126 |
+
|
127 |
+
|
128 |
+
class ExtractComponent(ToolComponent):
    """
    Ask the LLM for tagged values and store them in the environment
    extract_words(list) : tags whose content should be extracted
    system_prompt & last_prompt : prompts sent to the LLM for the extraction
    """

    def __init__(self, extract_words, system_prompt, last_prompt=None):
        """Store the configuration and build the default output-format prompt.

        The default prompt lists one ``<word> ... </word>`` template per entry
        of ``extract_words``; it is used whenever ``last_prompt`` is falsy.
        """
        super().__init__()
        self.extract_words = extract_words
        self.system_prompt = system_prompt

        tag_templates = [
            f"<{word}> the content you need to extract </{word}>"
            for word in extract_words
        ]
        self.default_prompt = (
            "Please strictly adhere to the following format for outputting:\n"
            + "".join(tag_templates)
        )

        if last_prompt:
            self.last_prompt = last_prompt
        else:
            self.last_prompt = self.default_prompt

    def func(self, agent):
        """Query the LLM once and save each extracted value in shared memory.

        For every word, the tagged content is written to
        ``agent.environment.shared_memory[word]``; when the extraction is
        falsy the whole LLM reply is stored instead. Always returns ``{}``.
        """
        reply = agent.LLM.get_response(
            agent.long_term_memory,
            self.system_prompt,
            self.last_prompt,
            stream=False,
        )
        for word in self.extract_words:
            agent.environment.shared_memory[word] = extract(reply, word) or reply

        return {}
|
165 |
+
|
166 |
+
|
167 |
+
"""Search sources: chatgpt/search engines/specific search sources/can even be multimodal (if it comes to clothing)"""
|
168 |
+
|
169 |
+
|
170 |
+
class WebSearchComponent(ToolComponent):
    """Search the web with Bing or Google and turn the snippets into a prompt."""

    __ENGINE_NAME__: List = ["google", "bing"]

    def __init__(self, engine_name: str, api: Dict):
        """
        :param engine_name: The name of the search engine used
        :param api: Pass in a dictionary, such as {"bing":"key1", "google":"key2", ...}, of course each value can also be a list, or more complicated
        """
        super(WebSearchComponent, self).__init__()

        # Determine whether the key and engine_name of the api are legal.
        assert engine_name in WebSearchComponent.__ENGINE_NAME__
        for api_name in api:
            assert api_name in WebSearchComponent.__ENGINE_NAME__

        self.api = api
        self.engine_name = engine_name

        # Dispatch table from engine name to the method implementing it.
        self.search: Dict = {"bing": self._bing_search, "google": self._google_search}

    def _bing_search(self, query: str, **kwargs):
        """Query the Bing Web Search API; extra ``kwargs`` are ignored.

        Returns ``{"meta data": [...]}`` with one snippet/title/link dict
        per result; the list is empty when the response has no web pages.
        Raises the ``requests`` HTTP error for a non-success status.
        """
        subscription_key = self.api["bing"]
        search_url = "https://api.bing.microsoft.com/v7.0/search"
        headers = {"Ocp-Apim-Subscription-Key": subscription_key}
        params = {
            "q": query,
            "textDecorations": True,
            "textFormat": "HTML",
            "count": 10,
        }

        response = requests.get(search_url, headers=headers, params=params)
        response.raise_for_status()
        # A response without "webPages" is treated as zero results.
        results = response.json().get("webPages", {}).get("value", [])

        metadata_results = [
            {
                "snippet": result["snippet"],
                "title": result["name"],
                "link": result["url"],
            }
            for result in results
        ]
        return {"meta data": metadata_results}

    def _google_search(self, query: str, **kwargs):
        """Query Google Custom Search; ``kwargs`` are forwarded to ``list()``.

        Returns ``{"meta data": [...]}`` with one snippet/title/link dict
        per result; the list is empty when the response has no "items".
        """
        # Always read the Google credentials, independent of which engine is
        # currently selected (mirrors _bing_search reading self.api["bing"]).
        google_api = self.api["google"]
        api_key = google_api["api_key"]
        cse_id = google_api["cse_id"]
        service = build("customsearch", "v1", developerKey=api_key)

        response = service.cse().list(q=query, cx=cse_id, num=10, **kwargs).execute()
        # A response without "items" is treated as zero results.
        results = response.get("items", [])

        metadata_results = [
            {
                "snippet": result["snippet"],
                "title": result["title"],
                "link": result["link"],
            }
            for result in results
        ]
        return {"meta data": metadata_results}

    def func(self, agent, **kwargs) -> Dict:
        """Search for keywords derived from the agent's latest message.

        The LLM is asked to condense the message into search keywords; if it
        yields none, the raw message is used as the query. The snippets of
        the first five results are concatenated into the returned prompt.
        """
        query = " "
        if len(agent.long_term_memory) > 0:
            last = agent.long_term_memory[-1]
            # Other code in this module reads `.content` from memory entries;
            # accept both attribute-style and mapping-style entries.
            query = last.content if hasattr(last, "content") else last["content"]

        response = agent.LLM.get_response(
            None,
            system_prompt=f"Please analyze the provided conversation and identify keywords that can be used for a search engine query. Format the output as <keywords>extracted keywords</keywords>:\nConversation:\n{query}",
            stream=False,
        )
        response = extract(response, "keywords")
        query = response if response else query

        search_results = self.search[self.engine_name](query=query, **kwargs)
        information = "".join(
            item["snippet"] for item in search_results["meta data"][:5]
        )
        return {
            "prompt": "You can refer to the following information to reply:\n"
            + information
        }

    def convert_search_engine_to(self, engine_name):
        """Switch the active engine; ``engine_name`` must be a known engine."""
        assert engine_name in WebSearchComponent.__ENGINE_NAME__
        self.engine_name = engine_name
|
264 |
+
|
265 |
+
|
266 |
+
class WebCrawlComponent(ToolComponent):
    """Fetch one web page in Chrome and return the text of its <p> elements."""

    def __init__(self):
        super().__init__()

    def func(self, agent_dict) -> Dict:
        """Crawl ``agent_dict["url"]`` and return ``{"content": text}``.

        Any exception raised while loading or parsing is printed and
        swallowed, so whatever text was gathered before the failure
        (possibly none) is still returned. The browser is always closed.
        """
        url = agent_dict["url"]
        print(f"crawling {url} ......")
        paragraphs = []
        # Crawling content from url may need to be carried out according to
        # different websites, such as wiki, baidu, zhihu, etc.
        driver = webdriver.Chrome()
        try:
            driver.get(url)

            # Give the page up to 20 seconds to expose a <body> element.
            WebDriverWait(driver, 20).until(
                EC.presence_of_element_located((By.TAG_NAME, "body"))
            )

            soup = BeautifulSoup(driver.page_source, "html.parser")

            # Collect paragraph texts; they are joined with newlines below.
            for paragraph in soup.find_all("p"):
                paragraphs.append(paragraph.get_text())
        except Exception as e:
            print("Error:", e)
        finally:
            driver.quit()
        return {"content": "\n".join(paragraphs).strip()}
|
301 |
+
|
302 |
+
|
303 |
+
class MailComponent(ToolComponent):
    """Read and send mail through the Gmail API.

    ``func`` dispatches to ``_read`` or ``_send`` according to the current
    action, which a call can override with ``mail_dict["action"]``.
    """

    __VALID_ACTION__ = ["read", "send"]

    def __init__(
        self, cfg_file: str, default_action: str = "read", name: str = "e-mail"
    ):
        """Validate the default action and obtain credentials.

        Args:
            cfg_file: Path of the OAuth client-secrets file, e.g.
                '../config/google_mail.json'.
            default_action: "read" or "send" (case-insensitive).
            name: Label stored on ``self.name``.
        """
        # `name` is kept on this instance rather than forwarded to the base
        # initializer.
        super(MailComponent, self).__init__()
        self.name = name
        assert (
            default_action.lower() in self.__VALID_ACTION__
        ), f"Action `{default_action}` is not allowed! The valid action is in `{self.__VALID_ACTION__}`"
        self.action = default_action.lower()
        self.credential = self._login(cfg_file)

    def _login(self, cfg_file: str):
        """Return Gmail credentials, reusing ``token.json`` when possible.

        Cached credentials that are still valid are returned as-is; expired
        ones with a refresh token are refreshed; otherwise the browser
        authorization flow runs. New or refreshed credentials are written
        back to ``token.json`` in the working directory.
        """
        SCOPES = [
            "https://www.googleapis.com/auth/gmail.readonly",
            "https://www.googleapis.com/auth/gmail.send",
        ]
        creds = None
        if os.path.exists("token.json"):
            creds = Credentials.from_authorized_user_file("token.json", SCOPES)
        if creds and creds.valid:
            # Announce success only once the cached token proved usable.
            print("Login Successfully!")
            return creds

        if creds and creds.expired and creds.refresh_token:
            creds.refresh(Request())
        else:
            print("Please authorize in an open browser.")
            flow = InstalledAppFlow.from_client_secrets_file(cfg_file, SCOPES)
            creds = flow.run_local_server(port=0)
        # Save the credentials for the next run
        with open("token.json", "w") as token:
            token.write(creds.to_json())
        return creds

    @staticmethod
    def _build_query(
        state, sender_mail, only_both, include_word, exclude_word, time_between
    ):
        """Assemble the Gmail search string from the read filters."""
        query = ""
        if state in ["unread", "read"]:
            query = f"is:{state}"
        if state in ["sent"]:
            query = f"in:{state}"
        if only_both:
            query = f"{query} from:{sender_mail} OR to:{sender_mail}"
        if sender_mail is not None and not only_both:
            query = f"{query} from:({sender_mail})"
        if include_word is not None:
            query = f"{query} {include_word}"
        if exclude_word is not None:
            query = f"{query} -{exclude_word}"
        if time_between is not None:
            TIME_FORMAT = "%Y/%m/%d"
            t1, t2 = time_between
            if t1 == "now":
                t1 = datetime.now().strftime(TIME_FORMAT)
            if t2 == "now":
                t2 = datetime.now().strftime(TIME_FORMAT)
            if isinstance(t1, str) and isinstance(t2, str):
                t1 = datetime.strptime(t1, TIME_FORMAT)
                t2 = datetime.strptime(t2, TIME_FORMAT)
            elif isinstance(t1, str) and isinstance(t2, int):
                # An int is a day offset relative to the other endpoint.
                t1 = datetime.strptime(t1, TIME_FORMAT)
                t2 = t1 + timedelta(days=t2)
            elif isinstance(t1, int) and isinstance(t2, str):
                t2 = datetime.strptime(t2, TIME_FORMAT)
                t1 = t2 + timedelta(days=t1)
            else:
                assert False, "invalid time"
            if t1 > t2:
                t1, t2 = t2, t1
            query = f"{query} after:{t1.strftime(TIME_FORMAT)} before:{t2.strftime(TIME_FORMAT)}"
        return query.strip()

    @staticmethod
    def _header_value(headers, name):
        """Return the value of the first header called ``name``, or ""."""
        for header in headers:
            if header["name"] == name:
                return header["value"]
        return ""

    @staticmethod
    def _plain_text_body(payload):
        """Return the decoded first text/plain body in ``payload``, or ""."""
        # Messages without "parts" carry their body on the payload itself.
        candidates = payload["parts"] if "parts" in payload else [payload]
        for part in candidates:
            if part.get("mimeType") != "text/plain":
                continue
            data = part.get("body", {}).get("data")
            if data:
                return base64.urlsafe_b64decode(data).decode("utf-8")
        return ""

    def _read(self, mail_dict: dict):
        """Search the mailbox and return the matching messages.

        Recognized ``mail_dict`` keys (all optional): ``state`` ("all",
        "unread", "read", "sent"), ``time_between`` (2-tuple), ``sender_mail``,
        ``only_both``, ``order_by_time`` ("descend"/"ascend"),
        ``include_word``, ``exclude_word``, ``MAX_SEARCH_CNT`` (default 50)
        and ``number`` (default 10).

        Returns:
            ``{"results": [...]}`` with up to ``number`` dicts holding
            sender, time, subject and body, sorted by time; ``None`` when
            nothing matches or when any error occurs (the error is printed).
        """
        credential = self.credential
        state = mail_dict.get("state")
        time_between = mail_dict.get("time_between")
        sender_mail = mail_dict.get("sender_mail")
        only_both = mail_dict.get("only_both", False)
        order_by_time = mail_dict.get("order_by_time", "descend")
        include_word = mail_dict.get("include_word")
        exclude_word = mail_dict.get("exclude_word")
        MAX_SEARCH_CNT = mail_dict.get("MAX_SEARCH_CNT", 50)
        number = mail_dict.get("number", 10)
        if state is None:
            state = "all"
        if time_between is not None:
            assert isinstance(time_between, tuple)
            assert len(time_between) == 2
        assert state in ["all", "unread", "read", "sent"]
        if only_both:
            assert sender_mail is not None
        if sender_mail is not None:
            assert isinstance(sender_mail, str)
        assert credential
        assert order_by_time in ["descend", "ascend"]

        try:
            list_kwargs = {
                "userId": "me",
                "q": self._build_query(
                    state,
                    sender_mail,
                    only_both,
                    include_word,
                    exclude_word,
                    time_between,
                ),
            }
            # labelIds must ALL match, so restricting an `in:sent` search to
            # INBOX would hide ordinary sent mail; only filter otherwise.
            if state != "sent":
                list_kwargs["labelIds"] = ["INBOX"]

            service = build("gmail", "v1", credentials=credential)
            results = service.users().messages().list(**list_kwargs).execute()
            messages = results.get("messages", [])

            if not messages:
                print("No eligible emails.")
                return None

            email_data = []
            # Slicing caps the work and keeps the progress bar total exact.
            for message in tqdm(messages[:MAX_SEARCH_CNT]):
                msg = (
                    service.users()
                    .messages()
                    .get(
                        userId="me",
                        id=message["id"],
                        format="full",
                        metadataHeaders=None,
                    )
                    .execute()
                )
                payload = msg["payload"]
                headers = payload["headers"]

                from_value = self._header_value(headers, "From")
                addresses = re.findall(r"\b[\w\.-]+@[\w\.-]+\.\w+\b", from_value)
                # Keep the raw header when no address can be parsed instead
                # of failing the whole read on one odd message.
                sender = addresses[0] if addresses else from_value

                email_data.append(
                    {
                        "sender": sender,
                        "time": datetime.fromtimestamp(
                            int(msg["internalDate"]) / 1000
                        ).strftime("%Y-%m-%d %H:%M:%S"),
                        "subject": self._header_value(headers, "Subject"),
                        "body": self._plain_text_body(payload),
                    }
                )

            email_data.sort(
                key=lambda x: datetime.strptime(x["time"], "%Y-%m-%d %H:%M:%S"),
                reverse=order_by_time == "descend",
            )
            return {"results": email_data[0:number]}
        except Exception as e:
            # Deliberate best-effort: report the problem and signal "no data".
            print(e)
            return None

    def _send(self, mail_dict: dict):
        """Send a plain-text mail.

        ``mail_dict`` must provide ``recipient_mail``, ``subject`` and
        ``body``. Returns ``{"state": True}`` on success and
        ``{"state": False}`` when the API reports an ``HttpError``.
        """
        recipient_mail = mail_dict["recipient_mail"]
        subject = mail_dict["subject"]
        body = mail_dict["body"]
        credential = self.credential
        service = build("gmail", "v1", credentials=credential)

        message = MIMEMultipart()
        message["to"] = recipient_mail
        message["subject"] = subject

        message.attach(MIMEText(body, "plain"))

        raw_message = base64.urlsafe_b64encode(message.as_bytes()).decode("utf-8")
        try:
            service.users().messages().send(
                userId="me", body={"raw": raw_message}
            ).execute()
            return {"state": True}
        except HttpError as error:
            print(error)
            return {"state": False}

    def func(self, mail_dict: dict):
        """Run the requested action (or the current default) on ``mail_dict``."""
        if "action" in mail_dict:
            action = mail_dict["action"].lower()
            assert action in self.__VALID_ACTION__
            # Store the normalized form; the dispatch table below is keyed by
            # lower-case names.
            self.action = action
        functions = {"read": self._read, "send": self._send}
        return functions[self.action](mail_dict)

    def convert_action_to(self, action_name: str):
        """Change the default action; ``action_name`` is case-insensitive."""
        assert (
            action_name.lower() in self.__VALID_ACTION__
        ), f"Action `{action_name}` is not allowed! The valid action is in `{self.__VALID_ACTION__}`"
        self.action = action_name.lower()
|
530 |
+
|
531 |
+
|
532 |
+
class WeatherComponet(ToolComponent):
    """Fetch daily weather for a city from the Weatherbit API."""

    def __init__(self, api_key, name="weather", TIME_FORMAT="%Y-%m-%d"):
        """Store the API key, the label and the date format used for I/O."""
        # `name` is kept on this instance rather than forwarded to the base
        # initializer.
        super(WeatherComponet, self).__init__()
        self.name = name
        self.TIME_FORMAT = TIME_FORMAT
        self.api_key = api_key

    def _parse(self, data):
        """Reduce an API response to ``{datetime: {field: value}}``.

        For each item in ``data["data"]`` the description and the
        temperature / precipitation fields are copied when present.
        """
        mapping = {
            "temp": "temperature",
            "max_temp": "max_temperature",
            "min_temp": "min_temperature",
            "precip": "accumulated_precipitation",
        }
        dict_data: dict = {}
        for item in data["data"]:
            date = item["datetime"]
            dict_data[date] = {}
            if "weather" in item:
                dict_data[date]["description"] = item["weather"]["description"]
            for key, readable in mapping.items():
                if key in item:
                    dict_data[date][readable] = item[key]
        return dict_data

    def _query(self, city_name, country_code, start_date, end_date):
        """Call Weatherbit and parse the answer.

        The "current" endpoint is used when the range is exactly
        today..tomorrow; any other range goes to the daily-history endpoint
        (https://www.weatherbit.io/api/historical-weather-daily).
        Raises the ``requests`` HTTP error for a non-success status.
        """
        now = datetime.now()
        today = datetime.strftime(now, self.TIME_FORMAT)
        tomorrow = datetime.strftime(now + timedelta(days=1), self.TIME_FORMAT)

        # Let requests build the query string so names containing spaces or
        # other reserved characters are encoded correctly.
        params = {"city": city_name, "country": country_code}
        if start_date == today and end_date == tomorrow:
            url = "https://api.weatherbit.io/v2.0/current"
        else:
            url = "https://api.weatherbit.io/v2.0/history/daily"
            params["start_date"] = start_date
            params["end_date"] = end_date
        params["key"] = self.api_key

        response = requests.get(url, params=params)
        # Surface HTTP failures here rather than as a KeyError in _parse.
        response.raise_for_status()
        return self._parse(response.json())

    def func(self, weather_dict: Dict) -> Dict:
        """Look up weather for the requested place and date range.

        ``weather_dict`` needs ``city_name`` (e.g. Beijing), ``country_code``
        (e.g. CN) and ``start_date`` (e.g. 2020-02-02); ``end_date`` is
        optional. Dates use ``self.TIME_FORMAT``; the two dates are swapped
        if given in reverse order and must differ.
        """
        TIME_FORMAT = self.TIME_FORMAT
        city_name = weather_dict["city_name"]
        country_code = weather_dict["country_code"]

        start = datetime.strptime(weather_dict["start_date"], TIME_FORMAT)
        raw_end = weather_dict.get("end_date")
        if raw_end is None:
            # NOTE(review): with no end_date the range becomes
            # [start - 1 day, start] after the swap below. _query treats
            # today..tomorrow as "today", which suggests +1 may have been
            # intended here; behavior left unchanged, please confirm.
            end = start + timedelta(days=-1)
        else:
            end = datetime.strptime(raw_end, TIME_FORMAT)

        if start > end:
            start, end = end, start

        start_date = datetime.strftime(start, TIME_FORMAT)
        end_date = datetime.strftime(end, TIME_FORMAT)
        assert start_date != end_date
        return self._query(city_name, country_code, start_date, end_date)
|
601 |
+
|
602 |
+
|
603 |
+
class TranslateComponent(ToolComponent):
    """Translate text with the Microsoft Translator REST API (v3.0)."""

    __SUPPORT_LANGUAGE__ = [
        "af", "am", "ar", "as", "az", "ba", "bg", "bn", "bo", "bs",
        "ca", "cs", "cy", "da", "de", "dsb", "dv", "el", "en", "es",
        "et", "eu", "fa", "fi", "fil", "fj", "fo", "fr", "fr-CA", "ga",
        "gl", "gom", "gu", "ha", "he", "hi", "hr", "hsb", "ht", "hu",
        "hy", "id", "ig", "ikt", "is", "it", "iu", "iu-Latn", "ja", "ka",
        "kk", "km", "kmr", "kn", "ko", "ku", "ky", "ln", "lo", "lt",
        "lug", "lv", "lzh", "mai", "mg", "mi", "mk", "ml", "mn-Cyrl",
        "mn-Mong", "mr", "ms", "mt", "mww", "my", "nb", "ne", "nl", "nso",
        "nya", "or", "otq", "pa", "pl", "prs", "ps", "pt", "pt-PT", "ro",
        "ru", "run", "rw", "sd", "si", "sk", "sl", "sm", "sn", "so",
        "sq", "sr-Cyrl", "sr-Latn", "st", "sv", "sw", "ta", "te", "th",
        "ti", "tk", "tlh-Latn", "tlh-Piqd", "tn", "to", "tr", "tt", "ty",
        "ug", "uk", "ur", "uz", "vi", "xh", "yo", "yua", "yue", "zh-Hans",
        "zh-Hant", "zu",
    ]

    # Common locale spellings mapped to the codes in __SUPPORT_LANGUAGE__.
    # Needed because the constructor default "zh-cn" is not itself a
    # supported code and would otherwise always fail validation.
    __LANGUAGE_ALIASES__ = {"zh-cn": "zh-Hans", "zh-tw": "zh-Hant"}

    def __init__(
        self, api_key, location, default_target_language="zh-cn", name="translate"
    ):
        """Store the subscription key, region, default language and label."""
        # `name` is kept on this instance rather than forwarded to the base
        # initializer.
        super(TranslateComponent, self).__init__()
        self.name = name
        self.api_key = api_key
        self.location = location
        self.default_target_language = default_target_language

    def func(self, translate_dict: Dict) -> Dict:
        """Translate ``translate_dict["content"]``.

        The target is ``translate_dict["target_language"]`` when given,
        otherwise the configured default; known aliases are mapped to
        supported codes before validation.

        Returns:
            ``{"result": translated_text}``.

        Raises:
            AssertionError: if the target language is not supported.
            requests.HTTPError: if the service answers with an error status.
        """
        content = translate_dict["content"]
        target_language = translate_dict.get(
            "target_language", self.default_target_language
        )
        target_language = self.__LANGUAGE_ALIASES__.get(
            str(target_language).lower(), target_language
        )
        assert (
            target_language in self.__SUPPORT_LANGUAGE__
        ), f"language `{target_language}` is not supported."

        endpoint = "https://api.cognitive.microsofttranslator.com"
        path = "/translate"
        constructed_url = endpoint + path

        params = {"api-version": "3.0", "to": target_language}
        headers = {
            "Ocp-Apim-Subscription-Key": self.api_key,
            "Ocp-Apim-Subscription-Region": self.location,
            "Content-type": "application/json",
            "X-ClientTraceId": str(uuid.uuid4()),
        }
        body = [{"text": content}]

        request = requests.post(
            constructed_url, params=params, headers=headers, json=body
        )
        request.raise_for_status()
        # Use the parsed JSON directly. Round-tripping it through
        # json.dumps() + eval() executed remote data as Python and failed on
        # JSON literals such as true/false/null.
        response = request.json()
        return {"result": response[0]["translations"][0]["text"]}
|
783 |
+
|
784 |
+
|
785 |
+
class APIComponent(ToolComponent):
    """Placeholder tool component; its ``func`` currently does nothing."""

    def __init__(self):
        super().__init__()

    def func(self, agent) -> Dict:
        """Not implemented yet: ignores ``agent`` and returns ``None``."""
        return None
|
791 |
+
|
792 |
+
|
793 |
+
class FunctionComponent(ToolComponent):
    """Let the LLM choose and call one of a set of Python functions."""

    def __init__(
        self,
        functions,
        function_call="auto",
        response_type="response",
        your_function=None,
    ):
        """Register the callable functions.

        Args:
            functions: Function descriptions passed to the LLM; each has a
                "name" and ``["parameters"]["properties"]``.
            function_call: Forwarded to the LLM call as ``function_call``.
            response_type: "response" or "prompt"; selects the key under
                which ``func`` returns the function's result.
            your_function: Optional ``{"name": ..., "content": ...}`` whose
                "content" is Python source defining the function "name".
        """
        super().__init__()
        self.functions = functions
        self.function_call = function_call
        self.parameters = {}
        self.available_functions = {}
        self.response_type = response_type

        # Names created by exec() are collected in an explicit namespace.
        # Relying on the implicit function locals is documented as unreliable
        # and no longer works with Python 3.13's locals() semantics.
        namespace = {}
        if your_function:
            function_name = your_function["name"]
            function_content = your_function["content"]
            # SECURITY: this executes arbitrary source taken from the
            # configuration; only pass trusted content.
            exec(function_content, globals(), namespace)
            self.available_functions[function_name] = eval(
                function_name, globals(), namespace
            )

        for function in self.functions:
            self.parameters[function["name"]] = list(
                function["parameters"]["properties"].keys()
            )
            # SECURITY: eval() of a configured name; resolved against the
            # exec'd namespace first, then this module's globals.
            self.available_functions[function["name"]] = eval(
                function["name"], globals(), namespace
            )

    def func(self, agent):
        """Ask the LLM which function to call and call it.

        Returns:
            ``{}`` when the LLM requests no function call; otherwise the
            function's return value under "response" or "prompt", depending
            on ``self.response_type`` (``{}`` for any other response type).
        """
        messages = agent.long_term_memory
        outputdict = {}
        query = (
            agent.long_term_memory[-1].content
            if len(agent.long_term_memory) > 0
            else " "
        )
        relevant_history = get_relevant_history(
            query,
            agent.long_term_memory[:-1],
            agent.chat_embeddings[:-1],
        )
        response_message = agent.LLM.get_response(
            messages,
            None,
            functions=self.functions,
            stream=False,
            function_call=self.function_call,
            relevant_history=relevant_history,
        )
        if response_message.get("function_call"):
            function_name = response_message["function_call"]["name"]
            function_to_call = self.available_functions[function_name]
            function_args = json.loads(response_message["function_call"]["arguments"])
            # Pass exactly the declared parameters; undeclared arguments are
            # dropped and missing ones become None.
            input_args = {
                args_name: function_args.get(args_name)
                for args_name in self.parameters[function_name]
            }
            function_response = function_to_call(**input_args)
            if self.response_type == "response":
                outputdict["response"] = function_response
            elif self.response_type == "prompt":
                outputdict["prompt"] = function_response

        return outputdict
|
851 |
+
|
852 |
+
|
853 |
+
class CodeComponent(ToolComponent):
    """Ask the LLM for code and write it to a file under ``output_code/``."""

    def __init__(self, file_name, keyword) -> None:
        """Store the output file name and the tag that wraps the code.

        Args:
            file_name: Name of the file created inside ``output_code/``.
            keyword: Tag name; the LLM is told to answer as
                ``<keyword> ... </keyword>``.
        """
        super().__init__()
        self.file_name = file_name
        self.keyword = keyword
        self.system_prompt = (
            "you need to extract the modified code as completely as possible."
        )
        self.last_prompt = (
            "Please strictly adhere to the following format for outputting: \n"
        )
        self.last_prompt += (
            f"<{self.keyword}> the content you need to extract </{self.keyword}>"
        )

    def func(self, agent):
        """Extract code from the LLM reply and save it; always returns ``{}``.

        The whole reply is used when the tagged extraction is falsy. A
        leading "```python" line and a trailing "```" line are stripped
        before writing.
        """
        response = agent.LLM.get_response(
            agent.long_term_memory,
            self.system_prompt,
            self.last_prompt,
            stream=False,
        )
        code = extract(response, self.keyword)
        code = code if code else response
        os.makedirs("output_code", exist_ok=True)
        file_name = "output_code/" + self.file_name
        codes = code.split("\n")
        # Drop the fence lines by position. list.remove() deletes the FIRST
        # equal line, which would strip an inner "```" and keep the closing
        # one; the emptiness checks cover a reply that is only a fence.
        if codes and codes[0] == "```python":
            codes.pop(0)
        if codes and codes[-1] == "```":
            codes.pop()
        code = "\n".join(codes)
        with open(file_name, "w", encoding="utf-8") as f:
            f.write(code)
        return {}
|
agents/Component/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .ExtraComponent import *
|
2 |
+
from .PromptComponent import *
|
3 |
+
from .ToolComponent import *
|
agents/Component/__pycache__/ExtraComponent.cpython-38.pyc
ADDED
Binary file (4.03 kB). View file
|
|
agents/Component/__pycache__/PromptComponent.cpython-38.pyc
ADDED
Binary file (6.15 kB). View file
|
|
agents/Component/__pycache__/ToolComponent.cpython-38.pyc
ADDED
Binary file (22.1 kB). View file
|
|
agents/Component/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (202 Bytes). View file
|
|
agents/Environment/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .base_environment import Environment
|
agents/Environment/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (173 Bytes). View file
|
|
agents/Environment/__pycache__/base_environment.cpython-38.pyc
ADDED
Binary file (4.28 kB). View file
|
|
agents/Environment/base_environment.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import get_relevant_history, get_embedding
|
2 |
+
import torch
|
3 |
+
from LLM.base_LLM import *
|
4 |
+
from Memory import Memory
|
5 |
+
from Prompt import *
|
6 |
+
import json
|
7 |
+
class Environment:
    """
    The place where the agents act; responsible for storing shared memories.

    ``shared_memory`` holds ``"long_term_memory"`` (every memory passed to
    ``update_memory``), ``"short_term_memory"`` (the latest summary returned by
    ``summary``, or ``None``) and, after the first ``update_memory`` call,
    ``"chat_embeddings"``.

    ``agents``, ``roles_to_names`` and ``names_to_roles`` start as ``None`` and
    are expected to be filled in by the caller after construction.
    """

    # Used when the MAX_CHAT_HISTORY environment variable is not set; this is
    # the same fallback value the OpenAI LLM wrapper uses.
    DEFAULT_MAX_CHAT_HISTORY = 10

    def __init__(self, config) -> None:
        """
        Build the environment from a parsed configuration dict.

        Args:
            config: dict with a required ``"states"`` mapping
                (state name -> state dict) and an optional
                ``"environment_type"`` (defaults to ``"cooperative"``).
        """
        self.shared_memory = {"long_term_memory": [], "short_term_memory": None}
        self.agents = None

        self.summary_system_prompt = {}
        self.summary_last_prompt = {}
        self.environment_prompt = {}
        self.environment_type = config.get("environment_type", "cooperative")
        self.current_chat_history_idx = 0
        self.LLMs = {}

        # Initialize the summary prompts and the LLM for each state.
        for state_name, state_dict in config["states"].items():
            if state_name == "end_state":
                continue

            if "summary_system_prompt" in state_dict:
                self.summary_system_prompt[state_name] = state_dict["summary_system_prompt"]
            else:
                self.summary_system_prompt[state_name] = eval(
                    Default_environment_summary_system_prompt
                )

            if "summary_last_prompt" in state_dict:
                self.summary_last_prompt[state_name] = state_dict["summary_last_prompt"]
            else:
                self.summary_last_prompt[state_name] = eval(
                    Default_environment_summary_last_prompt
                )

            self.environment_prompt[state_name] = state_dict.get("environment_prompt", " ")
            self.LLMs[state_name] = init_LLM(f"logs/{state_name}", **state_dict)

        self.roles_to_names = None
        self.names_to_roles = None

    @classmethod
    def from_config(cls, config_path):
        """Load the JSON file at ``config_path`` and build an instance from it."""
        with open(config_path) as f:
            config = json.load(f)
        return cls(config)

    @staticmethod
    def _max_chat_history():
        """Return MAX_CHAT_HISTORY from the process environment, or the default."""
        # Local import so this block does not depend on ``os`` leaking in
        # through a star import elsewhere in the module.
        import os

        raw = os.environ.get("MAX_CHAT_HISTORY")
        if raw is None:
            return Environment.DEFAULT_MAX_CHAT_HISTORY
        # NOTE(review): eval() on an environment value is kept for backward
        # compatibility with existing configs; consider int() once it is
        # confirmed that the value is always a plain integer.
        return eval(raw)

    def summary(self, current_state):
        """
        Summarize the situation in the current environment every once in a while.

        Args:
            current_state: object whose ``name`` selects the prompts and LLM.

        Returns:
            Whatever the state's LLM ``get_response`` returns (non-streaming).
        """
        MAX_CHAT_HISTORY = self._max_chat_history()
        current_state_name = current_state.name

        long_term_memory = self.shared_memory["long_term_memory"]
        query = long_term_memory[-1].content
        relevant_history = get_relevant_history(
            query,
            long_term_memory[:-1],
            self.shared_memory["chat_embeddings"][:-1],
        )

        # The local names below are read by the eval()'d prompt templates and
        # must not be renamed.
        relevant_history = Memory.get_chat_history(relevant_history)
        chat_history = Memory.get_chat_history(
            long_term_memory[-MAX_CHAT_HISTORY + 1 :]
        )
        summary = self.shared_memory["short_term_memory"]

        # system prompt = environment prompt + current memory + system prompt
        # current_memory = summary + chat history + relevant history
        current_memory = eval(Environment_summary_memory)
        environment_prompt = self.environment_prompt[current_state_name]
        summary_system_prompt = self.summary_system_prompt[current_state_name]

        environment_summary_system_prompt = eval(Environment_summary_system_prompt)
        response = self.LLMs[current_state_name].get_response(
            None, environment_summary_system_prompt, stream=False
        )
        return response

    def update_memory(self, memory, current_state):
        """
        Store a new memory and refresh everything derived from it.

        Appends ``memory`` to the long-term memory, appends its embedding to
        ``"chat_embeddings"``, regenerates the short-term summary every
        MAX_CHAT_HISTORY memories, and forwards the memory to the sending
        agent's own ``update_memory``.
        """
        MAX_CHAT_HISTORY = self._max_chat_history()
        self.shared_memory["long_term_memory"].append(memory)

        current_embedding = get_embedding(memory.content)
        if "chat_embeddings" not in self.shared_memory:
            self.shared_memory["chat_embeddings"] = current_embedding
        else:
            self.shared_memory["chat_embeddings"] = torch.cat(
                [self.shared_memory["chat_embeddings"], current_embedding], dim=0
            )

        if len(self.shared_memory["long_term_memory"]) % MAX_CHAT_HISTORY == 0:
            self.shared_memory["short_term_memory"] = self.summary(current_state)

        self.agents[memory.send_name].update_memory(memory)

    def _get_agent_last_conversation_idx(self, agent, current_long_term_memory):
        """Return the index of the last memory sent by ``agent``, or -1 if none."""
        last_conversation_idx = -1
        for i, history in enumerate(current_long_term_memory):
            if history.send_name == agent.name:
                last_conversation_idx = i
        return last_conversation_idx

    def _get_agent_new_memory(self, agent, current_long_term_memory):
        """Return, as chat-history text, everything said after ``agent`` last spoke."""
        last_conversation_idx = self._get_agent_last_conversation_idx(
            agent, current_long_term_memory
        )
        # -1 (never spoke) yields the whole list; the last index yields [].
        new_conversation = current_long_term_memory[last_conversation_idx + 1 :]

        # get chat history from new conversation
        return Memory.get_chat_history(new_conversation)

    def _observe(self, agent):
        """
        Build the user message describing what ``agent`` currently observes.

        Side effect: sets ``agent.relevant_memory``.

        Returns:
            dict ``{"role": "user", "content": <rendered observation>}``.
        """
        # cooperative: information is shared between different states;
        # "competive" (spelled as in existing configs): nothing is shared
        # between states, so only memories from current_chat_history_idx on
        # are visible.
        current_chat_history_idx = (
            self.current_chat_history_idx if self.environment_type == "competive" else 0
        )
        current_long_term_memory = self.shared_memory["long_term_memory"][current_chat_history_idx:]
        current_chat_embbedings = self.shared_memory["chat_embeddings"][current_chat_history_idx:]

        # relevant_memory
        query = current_long_term_memory[-1].content
        relevant_memory = get_relevant_history(
            query,
            current_long_term_memory[:-1],
            current_chat_embbedings[:-1],
        )
        relevant_memory = Memory.get_chat_history(relevant_memory, agent.name)

        relevant_memory = eval(Agent_observe_relevant_memory)
        agent.relevant_memory = relevant_memory

        # get chat history from new conversation
        conversations = self._get_agent_new_memory(agent, current_long_term_memory)

        # memory = relevant_memory + summary + history + query
        # (locals here are read by the eval()'d template; keep their names)
        query = current_long_term_memory[-1]
        current_memory = eval(Agent_observe_memory)

        return {"role": "user", "content": current_memory}
|
166 |
+
|
167 |
+
|
agents/LLM/__init__.py
ADDED
File without changes
|
agents/LLM/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (113 Bytes). View file
|
|
agents/LLM/__pycache__/base_LLM.cpython-38.pyc
ADDED
Binary file (3.61 kB). View file
|
|
agents/LLM/base_LLM.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractclassmethod
|
2 |
+
import openai
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
from Memory import Memory
|
6 |
+
from utils import save_logs
|
7 |
+
|
8 |
+
class LLM:
    """Base class for language-model back ends; subclasses implement get_response."""

    def __init__(self) -> None:
        pass

    @abstractclassmethod
    def get_response(cls, *args, **kwargs):
        """
        Return the model's response; must be overridden by subclasses.

        The decorator binds the class as first argument, so the signature has
        to accept it (the previous zero-argument form could not be called at
        all).

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError("subclasses of LLM must implement get_response")
|
15 |
+
|
16 |
+
|
17 |
+
class OpenAILLM(LLM):
    """
    LLM back end that talks to ``openai.ChatCompletion``.

    Connection settings (API_KEY, PROXY, API_BASE, ACTIVE_MODE) are read from
    the process environment on every ``get_response`` call.
    """

    def __init__(self, **kwargs) -> None:
        """
        Args (all optional keyword arguments):
            model: model name, default ``"gpt-3.5-turbo-16k-0613"``.
            temperature: sampling temperature, default ``0.3``.
            log_path: where ``save_logs`` writes, default ``"logs"``.
        """
        super().__init__()
        # NOTE(review): the environment value goes through eval(); kept for
        # backward compatibility, but int() would be safer if the value is
        # always a plain integer.
        self.MAX_CHAT_HISTORY = (
            eval(os.environ["MAX_CHAT_HISTORY"])
            if "MAX_CHAT_HISTORY" in os.environ
            else 10
        )

        self.model = kwargs.get("model", "gpt-3.5-turbo-16k-0613")
        self.temperature = kwargs.get("temperature", 0.3)
        self.log_path = kwargs.get("log_path", "logs")

    def get_stream(self, response, log_path, messages):
        """
        Yield the text of each streamed chunk, then log the full answer.

        ``save_logs`` runs only once the generator has been fully consumed.
        """
        ans = ""
        for res in response:
            if res:
                # A chunk without content contributes an empty string.
                r = res.choices[0]["delta"].get("content") or ""
                ans += r
                yield r

        save_logs(log_path, messages, ans)

    def _build_messages(self, chat_history, system_prompt, last_prompt, active_mode):
        """Assemble the message list sent to the chat-completion endpoint."""
        concise_hint = "Please keep your reply as concise as possible,Within three sentences, the total word count should not exceed 30"

        if active_mode:
            # Tolerate a missing system prompt instead of failing on None + str.
            system_prompt = (system_prompt or "") + concise_hint

        messages = [{"role": "system", "content": system_prompt}] if system_prompt else []

        if chat_history:
            if len(chat_history) > self.MAX_CHAT_HISTORY:
                chat_history = chat_history[-self.MAX_CHAT_HISTORY:]
            if isinstance(chat_history[0], dict):
                # Copy each dict: last_prompt is appended to the final message
                # below, and that must not alter the caller's history.
                messages += [dict(m) if isinstance(m, dict) else m for m in chat_history]
            elif isinstance(chat_history[0], Memory):
                messages += [memory.get_gpt_message("user") for memory in chat_history]

        if last_prompt:
            if active_mode:
                last_prompt = last_prompt + concise_hint
            if messages:
                # The last prompt is appended to the final message rather than
                # sent as a message of its own.
                messages[-1]["content"] += last_prompt
            else:
                # No system prompt and no history: send the prompt by itself.
                messages.append({"role": "user", "content": last_prompt})

        return messages

    def get_response(self,
                     chat_history,
                     system_prompt,
                     last_prompt=None,
                     stream=False,
                     functions=None,
                     function_call="auto",
                     WAIT_TIME=20,
                     **kwargs):
        """
        Return the LLM's response.

        Args:
            chat_history: list of message dicts or ``Memory`` objects (may be
                empty or ``None``); only the last MAX_CHAT_HISTORY are sent.
            system_prompt: text of the leading system message (may be falsy).
            last_prompt: text appended to the content of the final message.
            stream: request a streamed completion (ignored when ``functions``
                is given).
            functions, function_call: forwarded to the API when ``functions``
                is truthy.
            WAIT_TIME: seconds to sleep before retrying after an error.
            **kwargs: accepted and ignored.

        Returns:
            The response message object when ``functions`` is given, a
            generator of text chunks when ``stream`` is true, otherwise the
            response text.

        Raises:
            AssertionError: when the API reports that the maximum context
                length was exceeded.
        """
        openai.api_key = os.environ["API_KEY"]
        if "PROXY" in os.environ:
            assert "http:" in os.environ["PROXY"] or "socks" in os.environ["PROXY"],"PROXY error,PROXY must be http or socks"
            openai.proxy = os.environ["PROXY"]
        if "API_BASE" in os.environ:
            openai.api_base = os.environ["API_BASE"]
        # NOTE(review): active mode is switched on by ACTIVE_MODE == "0";
        # confirm this polarity is intended.
        active_mode = "ACTIVE_MODE" in os.environ and os.environ["ACTIVE_MODE"] == "0"
        model = self.model
        temperature = self.temperature

        messages = self._build_messages(chat_history, system_prompt, last_prompt, active_mode)

        # Retry until the request succeeds; any error other than a
        # context-length overflow is retried after WAIT_TIME seconds.
        while True:
            try:
                if functions:
                    response = openai.ChatCompletion.create(
                        model=model,
                        messages=messages,
                        functions=functions,
                        function_call=function_call,
                        temperature=temperature,
                    )
                else:
                    response = openai.ChatCompletion.create(
                        model=model,
                        messages=messages,
                        temperature=temperature,
                        stream=stream)
                break
            except Exception as e:
                print(e)
                if "maximum context length is" in str(e):
                    # Raised explicitly so the failure survives ``python -O``
                    # (a bare assert would be stripped and leave ``response``
                    # unbound).
                    raise AssertionError("exceed max length") from e
                print(f"Please wait {WAIT_TIME} seconds and resend later ...")
                time.sleep(WAIT_TIME)

        if functions:
            save_logs(self.log_path, messages, response)
            return response.choices[0].message
        elif stream:
            return self.get_stream(response, self.log_path, messages)
        else:
            save_logs(self.log_path, messages, response)
            return response.choices[0].message["content"]
|
121 |
+
|
122 |
+
|
123 |
+
def init_LLM(default_log_path, **kwargs):
    """
    Build the LLM client described by a configuration dict.

    Args:
        default_log_path: log path used when the configuration names none.
        **kwargs: configuration; the keys read here are ``"LLM_type"``
            (default ``"OpenAI"``), ``"log_path"`` and ``"LLM"`` (keyword
            arguments for the client). Any other keys are ignored.

    Returns:
        An ``OpenAILLM`` instance.

    Raises:
        ValueError: if ``"LLM_type"`` is not ``"OpenAI"`` (previously this
            surfaced as an unbound-local error on the return statement).
    """
    llm_type = kwargs.get("LLM_type", "OpenAI")
    log_path = kwargs.get("log_path", default_log_path)

    if llm_type != "OpenAI":
        raise ValueError(f"Unsupported LLM_type: {llm_type!r}")

    if "LLM" in kwargs:
        # Copy so the caller's config is untouched, and fall back to the
        # resolved log path when the LLM section does not set its own.
        llm_config = dict(kwargs["LLM"])
        llm_config.setdefault("log_path", log_path)
        return OpenAILLM(**llm_config)

    return OpenAILLM(model="gpt-3.5-turbo-16k-0613", temperature=0.3, log_path=log_path)
|
133 |
+
|
agents/Memory/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .base_Memory import Memory
|
agents/Memory/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (158 Bytes). View file
|
|
agents/Memory/__pycache__/base_Memory.cpython-38.pyc
ADDED
Binary file (1.43 kB). View file
|
|
agents/Memory/base_Memory.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from Prompt import *
|
2 |
+
class Memory:
    """A single utterance: the sender's role and name plus the text that was said."""

    def __init__(self, role, name, content) -> None:
        self.send_role, self.send_name, self.content = role, name, content

    def get_gpt_message(self, role):
        """Return this memory as a chat message dict carrying the given role."""
        return dict(role=role, content=self.content)

    @classmethod
    def get_chat_history(cls, messages, agent_name=None):
        """
        Splice a memory list into a sentence
        input :
        messages(list) : list of memory(Memory)
        agent_name(str, optional) : messages sent by this name are attributed to "you"
        Return :
        chat_history(str) : One sentence after integration
        """
        # ``name``, ``content`` and ``chat_history`` are read by the eval()'d
        # templates, so these local names are part of the contract.
        chat_history = ""
        for message in messages:
            name = message.send_name
            role = message.send_role
            content = message.content
            if agent_name and name == agent_name:
                name = "you"
            chat_history = chat_history + eval(Single_message)
        return eval(Chat_total_message)

    def get_query(self):
        "Return : query(str):last sentence"
        name = self.send_name
        role = self.send_role
        content = self.content
        return eval(Single_message)
|
agents/Prompt/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .base_Prompts import *
|
agents/Prompt/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (145 Bytes). View file
|
|
agents/Prompt/__pycache__/base_Prompts.cpython-38.pyc
ADDED
Binary file (3.43 kB). View file
|
|
agents/Prompt/base_Prompts.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Every constant below is the *source text* of a Python string expression.
# Callers pass it to eval(), so the {placeholders} are resolved against the
# local variables of the calling function; the names listed next to each
# template must exist there.

# SOP========================================================================================================
# "environment_prompt"
# needs: self (an object with current_state.environment_prompt)
Get_environment_prompt = "f\"The current scenario is as follows <environment> {self.current_state.environment_prompt} </environment>\""


# sop.transit
#================================================================
# needs: environment_prompt, judge_system_prompt
Transit_system_prompt = "f\"{environment_prompt};{judge_system_prompt}\""

# transit chat message
# needs: environment_summary, chat_history_message, query, relevant_history
Transit_message = "f\"{environment_summary};The chat history is as follows:\\n<chat> {chat_history_message}\\n</chat>;You especially need to pay attention to the last query<query>\\n{query}\\n</query> and the relevant conversation <relevant>\\n{relevant_history} \\n</relevant>\\n\""


# needs: judge_last_prompt
Transit_last_prompt = "f\"{judge_last_prompt}\""
#sop.transit================================================================

# sop.call
#================================================================
# Helps the controller determine the next role to speak; one line per role.
# needs: role
Allocate_component = "f\"If it's currently supposed to be speaking for {role}, then output <end>{role}</end>.\\n\""

# needs: environment_prompt, call_system_prompt, allocate_prompt
Call_system_prompt = "f\"{environment_prompt};{call_system_prompt};{allocate_prompt}\""

# needs: query, relevant_history, allocate_prompt, last_name
Call_last_prompt = "f\"You especially need to pay attention to the last query<query>\\n{query}\\n</query> and the relevant conversation <relevant>\\n{relevant_history} \\n</relevant>\\n;Now please choose the person to speak according to the following rules :{allocate_prompt};Note: The person whose turn it is now cannot be the same as the person who spoke last time, so {last_name} cannot be output\\n.\""

# needs: chat_history_message, last_name
Call_message = "f\"The chat history is as follows:\\n<history>\\n{chat_history_message}</history>\\n;The last person to speak is: {last_name}\\n. \""
#sop.call================================================================
# SOP========================================================================================================


# Memory========================================================================================================
# needs: name, content
Single_message = "f\"{name} said that :{content}\""

# needs: chat_history
Chat_total_message = "f\"{chat_history}\""
# Memory========================================================================================================


# Environment========================================================================================================
Default_environment_summary_system_prompt = "\"\\nYour task is to summarize the historical dialogue records according to the current scene, and summarize the most important information\""

# The evaluated text contains a literal backslash before each brace, exactly
# as before; the backslashes are now doubled at both quoting levels so neither
# this module nor the eval()'d expression contains an invalid escape sequence.
Default_environment_summary_last_prompt = "\"Please make a summary based on the historical chat records, the output format is history summary: \\\\{your summary content\\\\} \""

# needs: summary, chat_history, relevant_history
# Written as adjacent literals so source indentation can never leak into the
# prompt. The <information> block is now opened and closed properly, and the
# misspelled <hisroty> tag is corrected to match its </history>.
Environment_summary_memory = (
    "f\"The information you need to know is as follows:\\n<information>\\n"
    "The summary of the previous dialogue history is:<summary>\\n{summary}\\n.</summary>"
    "The latest conversation record is as follows:\\n<history> {chat_history}\\n</history>,"
    "the relevant chat history you may need is:<relevant>{relevant_history}</relevant>\\n</information>\""
)

# needs: environment_prompt, current_memory, summary_system_prompt
Environment_summary_system_prompt = "f\"{environment_prompt};{current_memory};{summary_system_prompt};\""


# observe
# needs: relevant_memory
Agent_observe_relevant_memory = "f\"The relevant chat history are as follows:\\n<relevant_history>{relevant_memory} </relevant_history>\\n\""


# needs: relevant_memory, agent (with short_term_memory), conversations
Agent_observe_memory = (
    "f\"Here's what you need to know(Remember, this is just information, Try not to repeat what's inside):\\n<information>\\n{relevant_memory};"
    "The previous summary of chat history is as follows :<summary>\\n{agent.short_term_memory}\\n</summary>."
    "The new chat history is as follows:\\n<history> {conversations}\\n</history>\\n"
    "</information>\""
)
# Environment========================================================================================================


# Agent========================================================================================================
# needs: summary_prompt, self (with short_term_memory), conversations
Agent_summary_system_prompt = "f\"{summary_prompt};Please summarize past key summary \\n<summary>\\n {self.short_term_memory} </summary>and new chat_history as follows: <history>\\n{conversations}</history>\""

# needs: last_prompt
Agent_last_prompt = "f\"{last_prompt};\\nPlease continue the talk based on your known information,Make an effort to make the conversation more coherent and try to respond differently from your existing knowledge, avoiding repeating what others have said.\""

# needs: system_prompt
Agent_system_prompt = "f\"{system_prompt},\""
# Agent========================================================================================================
|
agents/SOP.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The AIWaves Inc. team.
|
3 |
+
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""standard operation procedure of an LLM Autonomous agent"""
|
17 |
+
import random
|
18 |
+
from LLM.base_LLM import *
|
19 |
+
from State import State
|
20 |
+
from utils import extract, get_relevant_history
|
21 |
+
from Memory import Memory
|
22 |
+
from Prompt import *
|
23 |
+
import json
|
24 |
+
import os
|
25 |
+
|
26 |
+
class SOP:
    """
    Responsible for managing the operational processes of all agents.

    SOP should have args : "states" "relations" "root"; "user_names" and
    "finish_state_name" are optional. ``roles_to_names`` and
    ``names_to_roles`` start as ``None`` and are expected to be filled in by
    the caller before ``route`` or ``next`` is used.
    """

    def __init__(self, **kwargs):
        """Build the state graph and the controller LLM from the parsed config."""
        self.controller_dict = {}
        self.LLM = init_LLM("logs/god", **kwargs)

        self.states = {}
        self.init_states(kwargs["states"])
        self.init_relation(kwargs["relations"])
        for state_name, states_dict in kwargs["states"].items():
            if state_name != "end_state" and "controller" in states_dict:
                self.controller_dict[state_name] = states_dict["controller"]

        self.user_names = kwargs.get("user_names", [])
        self.root = self.states[kwargs["root"]]
        self.current_state = self.root
        self.finish_state_name = kwargs.get("finish_state_name", "end_state")
        self.roles_to_names = None
        self.names_to_roles = None
        self.finished = False

    @classmethod
    def from_config(cls, config_path):
        """
        Load the JSON config at ``config_path``, export its ``"config"``
        section to the process environment and build an instance.

        NOTE(review): ``os.environ.clear()`` wipes every inherited variable
        (PATH, proxies, ...) before the config values are applied; this is
        kept as found, but confirm it is intended.
        """
        with open(config_path) as f:
            config = json.load(f)

        os.environ.clear()
        for key, value in config["config"].items():
            # An empty API_BASE is left unset so the client keeps its default
            # endpoint; every other setting is exported, since other modules
            # read API_KEY and MAX_CHAT_HISTORY from os.environ.
            # NOTE(review): the nesting of the original if/else could not be
            # recovered from this view; this is the reading under which those
            # variables are actually set. Confirm against the repository.
            if key == "API_BASE" and value == "":
                continue
            os.environ[key] = value

        return cls(**config)

    def init_states(self, states_dict):
        """Create one State per config entry, recording its name in the entry."""
        for state_name, state_dict in states_dict.items():
            state_dict["name"] = state_name
            self.states[state_name] = State(**state_dict)

    def init_relation(self, relations):
        """Wire each state's ``next_states`` from the ``relations`` mapping."""
        for state_name, state_relation in relations.items():
            for idx, next_state_name in state_relation.items():
                self.states[state_name].next_states[idx] = self.states[next_state_name]

    def transit(self, chat_history, **kwargs):
        """
        Determine the next state based on the current situation
        Return :
        next_state(State) : the next state
        """
        # If it is a single loop node, just keep looping
        if len(self.current_state.next_states) == 1:
            next_state = "0"

        # Otherwise, the controller needs to determine which node to enter.
        else:
            current_state = self.current_state
            controller_dict = self.controller_dict[current_state.name]
            relevant_history = kwargs["relevant_history"]

            max_chat_nums = controller_dict.get("max_chat_nums", 1000)
            if current_state.chat_nums >= max_chat_nums:
                return self.current_state.next_states["1"]

            # Otherwise, let the controller judge whether to end.
            # The local names below are read by the eval()'d prompt templates.
            judge_system_prompt = controller_dict["judge_system_prompt"]
            environment_prompt = eval(Get_environment_prompt) if current_state.environment_prompt else ""
            transit_system_prompt = eval(Transit_system_prompt)

            judge_last_prompt = controller_dict["judge_last_prompt"]
            transit_last_prompt = eval(Transit_last_prompt)

            environment = kwargs["environment"]
            environment_summary = environment.shared_memory["short_term_memory"]
            chat_history_message = Memory.get_chat_history(chat_history)
            query = chat_history[-1].get_query()

            chat_messages = [
                {
                    "role": "user",
                    "content": eval(Transit_message)
                }
            ]

            extract_words = controller_dict.get("judge_extract_words", "end")

            response = self.LLM.get_response(
                chat_messages, transit_system_prompt, transit_last_prompt, stream=False, **kwargs
            )
            next_state = (
                response if response.isdigit() else extract(response, extract_words)
            )

            # If nothing usable was parsed, or the model named a transition
            # this state does not have, continue looping instead of failing
            # with a KeyError below.
            if (
                not isinstance(next_state, str)
                or not next_state.isdigit()
                or next_state not in self.current_state.next_states
            ):
                next_state = "0"

        next_state = self.current_state.next_states[next_state]
        return next_state

    def route(self, chat_history, **kwargs):
        """
        Determine the role that needs action based on the current situation
        Return :
        current_agent(Agent) : the next act agent
        """
        agents = kwargs["agents"]

        # Stays None when the controller type is not recognised, so the
        # random fallback further down picks a role instead of the method
        # failing on an unbound name.
        next_role = None

        # Start assigning roles after knowing which state you have entered.
        # If there is only one role in that state, assign it directly to him.
        if len(self.current_state.roles) == 1:
            next_role = self.current_state.roles[0]

        # Otherwise the controller determines
        else:
            relevant_history = kwargs["relevant_history"]
            # A state without a controller section uses the default "order".
            controller_dict = self.controller_dict.get(self.current_state.name, {})
            controller_type = controller_dict.get("controller_type", "order")

            # If controller type is rule, it is left to LLM to assign roles.
            if controller_type == "rule":
                # NOTE(review): this configured value is overwritten below by
                # the rendered Call_last_prompt template, which does not
                # reference it; confirm whether it should be included.
                call_last_prompt = controller_dict.get("call_last_prompt", "")

                allocate_prompt = ""
                roles = list(set(self.current_state.roles))
                for role in roles:
                    allocate_prompt += eval(Allocate_component)

                call_system_prompt = controller_dict.get("call_system_prompt", "")
                environment_prompt = eval(Get_environment_prompt) if self.current_state.environment_prompt else ""
                # call_system_prompt + environment + allocate_prompt
                call_system_prompt = eval(Call_system_prompt)

                query = chat_history[-1].get_query()
                last_name = chat_history[-1].send_name
                # last_prompt: note + last_prompt + query
                call_last_prompt = eval(Call_last_prompt)

                chat_history_message = Memory.get_chat_history(chat_history)
                # Intermediate historical conversation records
                chat_messages = [
                    {
                        "role": "user",
                        "content": eval(Call_message),
                    }
                ]

                extract_words = controller_dict.get("call_extract_words", "end")

                response = self.LLM.get_response(
                    chat_messages, call_system_prompt, call_last_prompt, stream=False, **kwargs
                )

                # get next role
                next_role = extract(response, extract_words)

            # Speak in order
            elif controller_type == "order":
                # If there is no begin role, it will be given directly to the first person.
                if not self.current_state.current_role:
                    next_role = self.current_state.roles[0]
                # otherwise advance to the next role, wrapping around
                else:
                    self.current_state.index = (self.current_state.index + 1) % len(self.current_state.roles)
                    next_role = self.current_state.roles[self.current_state.index]

            # random speak
            elif controller_type == "random":
                next_role = random.choice(self.current_state.roles)

        # If the next character is not available, pick one at random
        if next_role not in self.current_state.roles:
            next_role = random.choice(self.current_state.roles)

        self.current_state.current_role = next_role

        next_agent = agents[self.roles_to_names[self.current_state.name][next_role]]

        return next_agent

    def next(self, environment, agents):
        """
        Determine the next state and the agent that needs action based on the current situation

        Returns:
            ``(state, agent)``, or ``(None, None)`` once the finish state is
            reached (``self.finished`` is then set to True).
        """
        # If it is the first time to enter this state
        if self.current_state.is_begin:
            agent_name = self.roles_to_names[self.current_state.name][self.current_state.begin_role]
            agent = agents[agent_name]
            return self.current_state, agent

        # get relevant history
        long_term_memory = environment.shared_memory["long_term_memory"]
        query = long_term_memory[-1].content
        relevant_history = get_relevant_history(
            query,
            long_term_memory[:-1],
            environment.shared_memory["chat_embeddings"][:-1],
        )
        relevant_history = Memory.get_chat_history(relevant_history)

        next_state = self.transit(
            chat_history=environment.shared_memory["long_term_memory"][
                environment.current_chat_history_idx :
            ],
            relevant_history=relevant_history,
            environment=environment,
        )
        # If you enter the termination node, terminate directly
        if next_state.name == self.finish_state_name:
            self.finished = True
            return None, None

        self.current_state = next_state

        # If it is the first time to enter the state and there is a begin
        # role, it will be directly assigned to the begin role.
        if self.current_state.is_begin and self.current_state.begin_role:
            agent_name = self.roles_to_names[self.current_state.name][self.current_state.begin_role]
            agent = agents[agent_name]
            return self.current_state, agent

        next_agent = self.route(
            chat_history=environment.shared_memory["long_term_memory"][
                environment.current_chat_history_idx :
            ],
            agents=agents,
            relevant_history=relevant_history,
        )

        return self.current_state, next_agent
|
agents/State.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from Component import *
|
2 |
+
|
3 |
+
|
4 |
+
class State:
    """
    One sub-scene of an SOP: records which roles take part, who speaks first,
    and the prompt components configured for every role.
    """

    def __init__(self, **kwargs):
        """Build a state from keyword configuration.

        Required key: ``name``.  Optional keys read here:
        ``environment_prompt``, ``roles``, ``begin_role``, ``begin_query``,
        ``summary_prompt`` and ``agent_states``.
        """
        self.next_states = {}
        self.name = kwargs["name"]
        self.environment_prompt = kwargs.get("environment_prompt", "")

        # Role order: explicit "roles" wins, then the keys of "agent_states",
        # and finally a single placeholder role ``0``.
        if "roles" in kwargs:
            roles = kwargs["roles"]
        elif "agent_states" in kwargs:
            roles = list(kwargs["agent_states"].keys())
        else:
            roles = [0]
        if len(roles) == 0:
            roles = [0]
        self.roles = roles

        self.begin_role = kwargs["begin_role"] if "begin_role" in kwargs else roles[0]
        self.begin_query = kwargs.get("begin_query", None)
        self.is_begin = True
        self.summary_prompt = kwargs.get("summary_prompt", None)
        self.current_role = self.begin_role

        if "agent_states" in kwargs:
            self.components = self.init_components(kwargs["agent_states"])
        else:
            self.components = {}

        # Position of the opening role inside ``roles`` (0 when it is not listed).
        if self.begin_role in self.roles:
            self.index = self.roles.index(self.begin_role)
        else:
            self.index = 0
        self.chat_nums = 0

    def init_components(self, agent_states_dict: dict):
        """Instantiate the configured components for every role.

        Args:
            agent_states_dict: mapping ``role -> {component_name: component_args}``.

        Returns:
            dict: ``role -> {stored_key: component_instance}``.  Component
            names that are falsy or not recognised are skipped.
        """
        # component name -> (key used in the result, factory taking the args dict).
        # Factories are lambdas so a component class is only looked up when
        # that component is actually requested.
        builders = {
            "style": ("style", lambda args: StyleComponent(args["role"])),
            "task": ("task", lambda args: TaskComponent(args["task"])),
            "rule": ("rule", lambda args: RuleComponent(args["rule"])),
            "demonstrations": (
                "demonstrations",
                lambda args: DemonstrationComponent(args["demonstrations"]),
            ),
            "output": ("output", lambda args: OutputComponent(args["output"])),
            "last": ("last", lambda args: LastComponent(args["last_prompt"])),
            "cot": ("cot", lambda args: CoTComponent(args["demonstrations"])),
            "CustomizeComponent": (
                "CustomizeComponent",
                lambda args: CustomizeComponent(args["template"], args["keywords"]),
            ),
            "system": ("system", lambda args: SystemComponent(args["system_prompt"])),
            "StaticComponent": (
                "StaticComponent",
                lambda args: StaticComponent(args["output"]),
            ),
            # The knowledge base is stored under the key "tool".
            "KnowledgeBaseComponent": (
                "tool",
                lambda args: KnowledgeBaseComponent(
                    args["top_k"], args["type"], args["knowledge_path"]
                ),
            ),
            "CategoryRequirementsComponent": (
                "CategoryRequirementsComponent",
                lambda args: CategoryRequirementsComponent(args["information_path"]),
            ),
            "FunctionComponent": (
                "FunctionComponent",
                lambda args: FunctionComponent(args[""]),
            ),
            "ExtractComponent": (
                "ExtractComponent",
                lambda args: ExtractComponent(
                    args["extract_words"], args["system_prompt"], args["last_prompt"]
                ),
            ),
            "WebSearchComponent": (
                "WebSearchComponent",
                lambda args: WebSearchComponent(args["engine_name"], args["api"]),
            ),
            "WebCrawlComponent": (
                "WebCrawlComponent",
                lambda args: WebCrawlComponent(args["name"]),
            ),
            "CodeComponent": (
                "CodeComponent",
                lambda args: CodeComponent(args["file_name"], args["keyword"]),
            ),
        }

        agent_states = {}
        for role, components in agent_states_dict.items():
            built = {}
            for component, component_args in components.items():
                if not component or component not in builders:
                    continue
                stored_key, factory = builders[component]
                built[stored_key] = factory(component_args)
            agent_states[role] = built

        return agent_states
|
agents/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .evolve import *
|
2 |
+
from .SOP import *
|
3 |
+
from .State import *
|
4 |
+
from .utils import *
|
agents/__pycache__/SOP.cpython-38.pyc
ADDED
Binary file (5.41 kB). View file
|
|
agents/__pycache__/State.cpython-38.pyc
ADDED
Binary file (2.58 kB). View file
|
|
agents/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (237 Bytes). View file
|
|
agents/__pycache__/evolve.cpython-38.pyc
ADDED
Binary file (217 Bytes). View file
|
|
agents/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (13.1 kB). View file
|
|
agents/evolve.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The AIWaves Inc. team.
|
3 |
+
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
"""self evolution of an LLM autonomous agent"""
|
agents/template.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Template for an SOP configuration: fill in the placeholders below.
## default { "temperature": 0.3, "model": "gpt-3.5-turbo-16k-0613","log_path": "logs/{your name}"}
LLM = {
    "temperature": 0.0,
    "model": "gpt-3.5-turbo-16k-0613",
    "log_path": "logs/god"
}


# One entry per agent name; "roles" maps a state name to the role the agent
# plays in that state.  Each agent is a top-level key of this dict.
Agents = {
    "Lilong" : {
        "style" : "professional",
        "roles" : {
            "company" : "coder",
            "state2" : "role2",
        },
    },
    "name2" : {
        "style" : "professional",
        "roles" : {
            "company" : "coder",
            "state2" : "role2",
        },
    },
}

# indispensable parameter: "controller_type"("order","random","rule")
# default extract words: "end". You can choose not to fill in this parameter
controller = {
    "controller_type": "order",
    "max_chat_nums" : 12,
    "judge_system_prompt": "",
    "judge_last_prompt": "",
    "judge_extract_words": "end",
    "call_system_prompt" : "",
    "call_last_prompt": "",
    "call_extract_words": ""
}

# Per-role component configuration: role name -> component settings.
Agent_state = {
    "role": {
        "LLM_type": "OpenAI",
        "LLM": LLM,
        "style": {
            "role": "Opening Advocate for the Affirmative",
            "style": "professional"
        },
        "task": {
            "task": ""
        },
        "rule": {
            "rule": ""
        }
    },
}


# indispensable parameter: "agent_states","controller"
# "roles" determines the speaking order when the rule is order. If not set, it is the default order.
# "begin_query" & "begin_role" determine the first speaker. They often determine the direction of the next speech. If you do not set them, the first agent is used by default.
# "environment_prompt" : Responsible for setting the scene for the current environment
State = {
    "controller": controller,
    "begin_role": "",
    "begin_query": "",
    "environment_prompt": "",
    "roles": ["role1","role2"],
    "LLM_type": "OpenAI",
    "LLM": LLM,
    # Key must be "agent_states" (plural): that is the indispensable parameter
    # named above and the key used by "end_state" below.
    "agent_states" : Agent_state,
}



States = {
    "end_state":{
        "agent_states":{}
    },
    "state1" : State

}


# default finish_state_name is "end_state"
# "environment_type" : "competive" : different states do not share the memory; "cooperative": different states share the memory
# NOTE(review): "relations" below mentions "state2", which is not defined in
# States in this template -- add it when filling the template in.
SOP = {
    "config" : {
        "API_KEY" : "Your key",
        "PROXY" : "Your PROXY",
        "MAX_CHAT_HISTORY" : "5",
        "User_Names" : "[\"alexander\"]"
    },
    "environment_type" : "competive",
    "LLM_type": "OpenAI",
    "LLM" :LLM,
    "root": "state1",
    "finish_state_name" : "end_state",
    "relations": {
        "state1": {
            "0": "state1",
            "1": "state2"
        },
        "state2":{
            "0":"state2",
            "1":"end_state"
        }
    },
    "agents": Agents,
    "states": States,
}
|
111 |
+
|
agents/utils.py
ADDED
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The AIWaves Inc. team.
|
3 |
+
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""helper functions for an LLM autonomous agent"""
|
17 |
+
import csv
|
18 |
+
import random
|
19 |
+
import json
|
20 |
+
import pandas
|
21 |
+
import numpy as np
|
22 |
+
import requests
|
23 |
+
import torch
|
24 |
+
from tqdm import tqdm
|
25 |
+
from text2vec import semantic_search
|
26 |
+
import re
|
27 |
+
import datetime
|
28 |
+
from langchain.document_loaders import UnstructuredFileLoader
|
29 |
+
from langchain.text_splitter import CharacterTextSplitter
|
30 |
+
from sentence_transformers import SentenceTransformer
|
31 |
+
import string
|
32 |
+
import random
|
33 |
+
import os
|
34 |
+
import openai
|
35 |
+
|
36 |
+
# Name of the embedding model; overridable through the "Embed_Model"
# environment variable.
embed_model_name = os.environ.get("Embed_Model", "text-embedding-ada-002")

# Any model other than the OpenAI one is loaded locally (on CPU) through
# SentenceTransformer; the OpenAI model needs no local object.
if embed_model_name not in ["text-embedding-ada-002"]:
    embedding_model = SentenceTransformer(
        embed_model_name, device=torch.device("cpu")
    )
|
43 |
+
|
44 |
+
def get_embedding(sentence):
    """Embed ``sentence`` and return the result as a 2-D torch tensor.

    When the module-level ``embed_model_name`` is "text-embedding-ada-002" the
    OpenAI embedding endpoint is used, configured from the ``API_KEY``,
    ``PROXY`` and ``API_BASE`` environment variables.  Otherwise the
    module-level SentenceTransformer ``embedding_model`` encodes the input.

    Args:
        sentence: text to embed.  NOTE(review): ``matching_category`` passes a
            list of strings; the OpenAI branch keeps only the first returned
            embedding (``data[0]``) -- confirm that is intended.

    Returns:
        torch.Tensor: the embedding; a 1-D result is unsqueezed to shape
        ``(1, dim)``.

    Raises:
        KeyError: ``API_KEY`` is not set on the OpenAI path.
        AssertionError: ``PROXY`` is set but contains neither "http:" nor "socks".
    """
    if embed_model_name in ["text-embedding-ada-002"]:
        openai.api_key = os.environ["API_KEY"]
        if "PROXY" in os.environ:
            assert "http:" in os.environ["PROXY"] or "socks" in os.environ["PROXY"],"PROXY error,PROXY must be http or socks"
            openai.proxy = os.environ["PROXY"]
        if "API_BASE" in os.environ:
            openai.api_base = os.environ["API_BASE"]
        # Call openai.Embedding directly.  Binding it to a local named
        # ``embedding_model`` would make that name local to the whole function
        # and the SentenceTransformer branch below would then fail with
        # UnboundLocalError instead of using the module-level model.
        response = openai.Embedding.create(
            model=embed_model_name,
            input=sentence
        )
        embed = torch.tensor(response["data"][0]["embedding"], dtype=torch.float32)
    else:
        embed = embedding_model.encode(sentence, convert_to_tensor=True)
    if len(embed.shape) == 1:
        embed = embed.unsqueeze(0)
    return embed
|
64 |
+
|
65 |
+
|
66 |
+
def get_code():
    """Return a random 8-character code of distinct ASCII letters and digits."""
    alphabet = string.ascii_letters + string.digits
    picked = random.sample(alphabet, 8)
    return "".join(picked)
|
68 |
+
|
69 |
+
|
70 |
+
def get_content_between_a_b(start_tag, end_tag, text):
    """Collect every span of ``text`` enclosed by ``start_tag`` ... ``end_tag``.

    Args:
        start_tag (str): opening marker
        end_tag (str): closing marker
        text (str): complete sentence

    Returns:
        str: the enclosed spans joined by single spaces, with surrounding
        whitespace stripped; "" when no complete pair is found.
    """
    pieces = []
    cursor = text.find(start_tag)
    while cursor != -1:
        content_start = cursor + len(start_tag)
        content_end = text.find(end_tag, content_start)
        if content_end == -1:
            # Opening tag without a closing tag: stop scanning.
            break
        pieces.append(text[content_start:content_end])
        cursor = text.find(start_tag, content_end + len(end_tag))
    return " ".join(pieces).strip()
|
93 |
+
|
94 |
+
|
95 |
+
def extract(text, type):
    """Return the content found between ``<type>`` and ``</type>`` in ``text``.

    Args:
        text (str): complete sentence
        type (str): tag name, without angle brackets

    Returns:
        str: content between <type></type>
    """
    opening, closing = f"<{type}>", f"</{type}>"
    return get_content_between_a_b(opening, closing, text)
|
107 |
+
|
108 |
+
def count_files_in_directory(directory):
    """Count the regular files located directly inside ``directory``."""
    return sum(
        1
        for entry in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, entry))
    )
|
112 |
+
|
113 |
+
def delete_oldest_files(directory, num_to_keep):
    """Delete the oldest regular files located directly inside ``directory``.

    Args:
        directory: directory to clean up; sub-directories are left untouched.
        num_to_keep: despite its name, the number of files to DELETE, oldest
            modification time first.  The parameter name is kept for
            backward compatibility with existing callers.
    """
    paths = (os.path.join(directory, entry) for entry in os.listdir(directory))
    # Sort by modification time so the oldest files really go first;
    # os.listdir() returns entries in arbitrary order.
    oldest_first = sorted(
        (path for path in paths if os.path.isfile(path)),
        key=os.path.getmtime,
    )
    for path in oldest_first[:max(num_to_keep, 0)]:
        os.remove(path)
|
121 |
+
|
122 |
+
def delete_files_if_exceed_threshold(directory, threshold, num_to_keep):
    """Trim ``directory`` once it holds more than ``threshold`` files.

    When the limit is exceeded, ``file_count - num_to_keep`` files are handed
    to ``delete_oldest_files`` for removal.
    """
    file_count = count_files_in_directory(directory)
    if file_count <= threshold:
        return
    delete_oldest_files(directory, file_count - num_to_keep)
|
128 |
+
|
129 |
+
def save_logs(log_path, messages, response):
    """Write one request/response pair to a timestamped JSON file.

    Args:
        log_path: directory for the log files; a falsy value falls back to "logs".
        messages: the input sent to the model, stored under "input".
        response: the model output, stored under "output".

    Before writing, the directory is trimmed via
    ``delete_files_if_exceed_threshold(log_path, 20, 10)``.
    """
    # Resolve the default before touching the filesystem: the fallback used
    # to be applied only after os.path.exists()/os.mkdir() had already been
    # called with the raw (possibly None or empty) value.
    log_path = log_path if log_path else "logs"
    # makedirs also creates missing parents (e.g. "logs/god"), unlike os.mkdir.
    os.makedirs(log_path, exist_ok=True)
    delete_files_if_exceed_threshold(log_path, 20, 10)

    log = {"input": messages, "output": response}
    # NOTE(review): the ":" characters in the timestamp are not valid in
    # Windows file names, and two calls within the same second share one
    # file name.  The format is left unchanged to keep existing names stable.
    log_file = os.path.join(
        log_path,
        datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + ".json")
    with open(log_file, "w", encoding="utf-8") as f:
        json.dump(log, f, ensure_ascii=False, indent=2)
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
def semantic_search_word2vec(query_embedding, kb_embeddings, top_k):
    """Thin wrapper that forwards to ``semantic_search`` with the given ``top_k``."""
    hits = semantic_search(query_embedding, kb_embeddings, top_k=top_k)
    return hits
|
148 |
+
|
149 |
+
|
150 |
+
def cut_sent(para):
    """Split ``para`` into sentences and regroup them three at a time.

    A line break is inserted after sentence-ending punctuation, six-dot and
    double-ellipsis runs, and closing quotes that follow sentence punctuation.
    The text is then split on those breaks and empty pieces are dropped.

    Args:
        para (str): text to split.

    Returns:
        list[str]: chunks, each made of up to three consecutive sentences
        joined by a single space.
    """
    # Raw strings keep the regex escapes (\?, \., \…) intact without relying
    # on Python passing unknown string escapes through, which triggers
    # invalid-escape warnings.  The compiled patterns are unchanged.
    para = re.sub(r"([。!?\?])([^”’])", r"\1\n\2", para)
    para = re.sub(r"(\.{6})([^”’])", r"\1\n\2", para)
    para = re.sub(r"(\…{2})([^”’])", r"\1\n\2", para)
    para = re.sub(r"([。!?\?][”’])([^,。!?\?])", r"\1\n\2", para)
    para = para.rstrip()
    pieces = [piece for piece in para.split("\n") if piece]
    batch_size = 3
    return [
        " ".join(pieces[start:start + batch_size])
        for start in range(0, len(pieces), batch_size)
    ]
|
163 |
+
|
164 |
+
|
165 |
+
def _knowledge_base_save_path(file_path):
    """Return ``temp_database/<stem>.json`` for ``file_path`` (directories and extension dropped)."""
    stem = os.path.splitext(os.path.basename(file_path))[0]
    return os.path.join("temp_database/", stem + ".json")


def process_document(file_path):
    """Embed a document and store the result as a JSON knowledge base.

    ``.csv`` files must provide "question" and "answer" columns.  Every Q/A
    pair is embedded five ways, in this order: question + answer chunk,
    answer chunk, question, question + answer, answer.  Any other file is
    loaded with ``UnstructuredFileLoader``, split into 200-character pieces
    with an overlap of 100, and each piece is embedded.

    Args:
        file_path (str): path of the document to process.

    Returns:
        dict: ``{"knowledge_base": <json path>, "type": "QA"}`` for csv input,
        ``{"knowledge_base": <json path>, "type": "UnstructuredFile"}`` otherwise.

    Json format:
        QA:   Dict[num, Dict[q: str, a: str, chunk: str, emb: List[float]]]
        file: Dict[num, Dict[chunk: str, emb: List[float]]]
    """
    final_dict = {}
    os.makedirs("temp_database", exist_ok=True)
    # The previous name derivation used file_path.split(".")[1], which breaks
    # for paths such as "./data/x.csv" or "v1.2/x.csv"; the non-csv branch
    # also joined the whole input path under temp_database/.
    save_path = _knowledge_base_save_path(file_path)

    if file_path.endswith(".csv"):
        dataset = pandas.read_csv(file_path)
        pairs = list(zip(dataset["question"], dataset["answer"]))

        def add_entry(q, a, chunk, text_to_embed):
            # Keys are consecutive integers in insertion order.
            final_dict[len(final_dict)] = {
                "q": q,
                "a": a,
                "chunk": chunk,
                "emb": get_embedding(text_to_embed).tolist(),
            }

        # embedding q+chunk
        for q, a in pairs:
            for text in cut_sent(a):
                add_entry(q, a, text, q + text)
        # embedding chunk
        for q, a in pairs:
            for text in cut_sent(a):
                add_entry(q, a, text, text)
        # embedding q
        for q, a in pairs:
            add_entry(q, a, a, q)
        # embedding q+a
        for q, a in pairs:
            add_entry(q, a, a, q + a)
        # embedding a
        for q, a in pairs:
            add_entry(q, a, a, a)

        print(f"finish updating {len(final_dict)} data!")
        print(save_path)
        with open(save_path, "w") as f:
            json.dump(final_dict, f, ensure_ascii=False, indent=2)
        return {"knowledge_base": save_path, "type": "QA"}

    loader = UnstructuredFileLoader(file_path)
    docs = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=200,
                                          chunk_overlap=100)
    # Only the first loaded document is split and embedded.
    docs = text_splitter.split_text(docs[0].page_content)
    for chunk in tqdm(docs):
        final_dict[len(final_dict)] = {
            "chunk": chunk,
            "emb": get_embedding(chunk).tolist(),
        }
    print(f"finish updating {len(final_dict)} data!")
    with open(save_path, "w") as f:
        json.dump(final_dict, f, ensure_ascii=False, indent=2)
    return {"knowledge_base": save_path, "type": "UnstructuredFile"}
|
261 |
+
|
262 |
+
def load_knowledge_base_qa(path):
    """
    Load a Q&A knowledge base stored as JSON.

    Records are read in index order under the keys "0", "1", ...

    Returns:
        tuple: (embeddings as a squeezed float32 torch tensor, questions,
        answers, chunks).
    """
    print("path", path)
    with open(path, "r") as f:
        data = json.load(f)

    embeddings, questions, answers, chunks = [], [], [], []
    for idx in range(len(data.keys())):
        record = data[str(idx)]
        embeddings.append(record["emb"])
        questions.append(record["q"])
        answers.append(record["a"])
        chunks.append(record["chunk"])

    matrix = np.array(embeddings, dtype=np.float32)
    return torch.from_numpy(matrix).squeeze(), questions, answers, chunks
|
281 |
+
|
282 |
+
|
283 |
+
def load_knowledge_base_UnstructuredFile(path):
    """
    Load a chunk-only knowledge base stored as JSON.

    Records are read in index order under the keys "0", "1", ...

    Returns:
        tuple: (embeddings as a squeezed float32 torch tensor, chunks).
    """
    with open(path, "r") as f:
        data = json.load(f)

    embeddings, chunks = [], []
    for idx in range(len(data.keys())):
        record = data[str(idx)]
        embeddings.append(record["emb"])
        chunks.append(record["chunk"])

    matrix = np.array(embeddings, dtype=np.float32)
    return torch.from_numpy(matrix).squeeze(), chunks
|
297 |
+
|
298 |
+
|
299 |
+
def cos_sim(a: torch.Tensor, b: torch.Tensor):
    """
    Pairwise cosine similarity between the rows of ``a`` and the rows of ``b``.

    Non-tensor inputs are converted with ``torch.tensor`` and 1-D inputs are
    treated as a single row.

    :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
    """

    def _as_rows(value):
        tensor = value if isinstance(value, torch.Tensor) else torch.tensor(value)
        return tensor.unsqueeze(0) if len(tensor.shape) == 1 else tensor

    rows_a = _as_rows(a)
    rows_b = _as_rows(b)
    unit_a = torch.nn.functional.normalize(rows_a, p=2, dim=1)
    unit_b = torch.nn.functional.normalize(rows_b, p=2, dim=1)
    return torch.mm(unit_a, unit_b.transpose(0, 1))
|
319 |
+
|
320 |
+
|
321 |
+
def matching_a_b(a, b, requirements=None):
    """Return the first row of the cosine-similarity matrix between the
    embeddings of ``a`` and ``b``.  ``requirements`` is accepted but unused.
    """
    embedding_a = get_embedding(a)
    embedding_b = get_embedding(b)
    return cos_sim(embedding_a, embedding_b)[0]
|
327 |
+
|
328 |
+
|
329 |
+
def matching_category(inputtext,
                      forest_name,
                      requirements=None,
                      cat_embedder=None,
                      top_k=3):
    """
    Rank category names by embedding similarity.

    Args:
        inputtext: the category name to be matched
        forest_name: category names, indexed like the rows of ``cat_embedder``
        requirements: optional space-separated requirement words; when given,
            the ranking uses the mean similarity of these words instead of
            the similarity of ``inputtext``
        cat_embedder: embeddings of the categories
        top_k: number of results to return (three by default)
    Return:
        [names, scores, indices] of the ``top_k`` best-scoring categories.
    """
    # Placeholder scores used when ``inputtext`` is empty.
    sim_scores = torch.zeros([100])
    if inputtext:
        sim_scores = cos_sim(get_embedding(inputtext), cat_embedder)[0]

    if requirements:
        requirements = requirements.split(" ")
        per_word_scores = cos_sim(get_embedding(requirements), cat_embedder)
        total_scores = torch.mean(per_word_scores, dim=0)
    else:
        total_scores = sim_scores

    best = torch.topk(total_scores, k=top_k)
    top_k_score = best[0]
    top_k_idx = best[1]
    top_k_name = [forest_name[top_k_idx[rank]] for rank in range(top_k)]

    return [top_k_name, top_k_score.tolist(), top_k_idx]
|
362 |
+
|
363 |
+
|
364 |
+
def sample_with_order_preserved(lst, num):
    """Pick ``num`` random elements of ``lst``, returned in their original order."""
    positions = list(range(len(lst)))
    # Sorting the sampled positions restores the original ordering.
    chosen = sorted(random.sample(positions, num))
    return [lst[position] for position in chosen]
|
370 |
+
|
371 |
+
|
372 |
+
def limit_values(data, max_values):
    """Shrink, in place, every list in ``data`` longer than ``max_values`` to a
    random sample of that size that keeps the original element order.

    Returns the same ``data`` dictionary.
    """
    for key in data:
        current = data[key]
        if len(current) > max_values:
            # Re-assigning an existing key is safe while iterating.
            data[key] = sample_with_order_preserved(current, max_values)
    return data
|
378 |
+
|
379 |
+
|
380 |
+
def limit_keys(data, max_keys):
    """Return ``data`` limited to at most ``max_keys`` randomly chosen keys.

    The input is returned unchanged when it is already small enough;
    otherwise a new dictionary is built with the sampled keys in their
    original order.
    """
    all_keys = list(data.keys())
    if len(all_keys) <= max_keys:
        return data
    kept_keys = sample_with_order_preserved(all_keys, max_keys)
    return {key: data[key] for key in kept_keys}
|
387 |
+
|
388 |
+
|
389 |
+
def flatten_dict(nested_dict):
    """
    Collapse a nested dictionary into a single level.

    Keys whose value is a dictionary are dropped and replaced by the
    flattened contents of that dictionary; a leaf key met again later
    overwrites the earlier value.
    """
    flat = {}
    for name, item in nested_dict.items():
        flat.update(flatten_dict(item) if isinstance(item, dict) else {name: item})
    return flat
|
401 |
+
|
402 |
+
|
403 |
+
def merge_list(list1, list2):
    """Append to ``list1`` (in place) each element of ``list2`` it does not
    already contain, then return ``list1``.
    """
    for candidate in list2:
        if candidate in list1:
            continue
        list1.append(candidate)
    return list1
|
408 |
+
|
409 |
+
|
410 |
+
def Search_Engines(req):
    """Query the shopping-search service for ``req``.

    The service URL is read from the ``SHOPPING_SEARCH`` environment variable
    and the number of items to fetch from ``FETSIZE`` (default 5).

    Args:
        req (str): search keywords.

    Returns:
        tuple: ``(items, top_categories)`` taken from the "data" field of the
        response, or ``([], [])`` when the response has no "data" field.

    Raises:
        KeyError: ``SHOPPING_SEARCH`` is not set.
    """
    # NOTE(security): eval() on an environment variable executes arbitrary
    # code if the environment is not trusted; int() would be safer.  Left
    # as-is so existing FETSIZE values keep their current meaning.
    FETSIZE = eval(os.environ["FETSIZE"]) if "FETSIZE" in os.environ else 5

    payload = {"keyword": req, "catLeafName": "", "fetchSize": FETSIZE}
    url = os.environ["SHOPPING_SEARCH"]
    res = requests.post(
        url=url,
        json=payload,
    )
    user_dict = json.loads(res.text)
    if "data" in user_dict:
        request_items = user_dict["data"]["items"]  # JSON of the matching products
        top_category = user_dict["data"]["topCategories"]
        return request_items, top_category
    # Always return a pair: search_with_api unpacks two values, so a bare []
    # made that unpacking fail with ValueError.
    return [], []
|
426 |
+
|
427 |
+
|
428 |
+
def search_with_api(requirements, categery):
    """Search for products, relaxing the requirements until enough are found.

    The space-separated ``requirements`` are searched together (with
    ``categery`` appended when it is not already one of them).  While fewer
    than ``FETSIZE`` items have been collected, the leading requirement is
    dropped and the search is repeated.

    Args:
        requirements (str): space-separated requirement words.
        categery (str): category keyword added to every query.

    Returns:
        tuple: ``(items, categories)`` -- at most ``FETSIZE`` merged items and
        the top categories of the last query, excluding those containing "其它".
    """
    # NOTE(security): eval() on an environment variable executes arbitrary
    # code if the environment is not trusted; int() would be safer.
    FETSIZE = eval(os.environ["FETSIZE"]) if "FETSIZE" in os.environ else 5

    request_items = []
    # Defined up front so the function still returns cleanly when the loop
    # below never runs (e.g. FETSIZE <= 0).
    top_category = []
    all_req_list = requirements.split(" ")
    count = 0

    while len(request_items) < FETSIZE and len(all_req_list) > 0:
        if count:
            # From the second round on, drop the leading requirement.
            all_req_list.pop(0)
        all_req = (" ").join(all_req_list)
        if categery not in all_req_list:
            all_req = all_req + " " + categery
        now_request_items, top_category = Search_Engines(all_req)
        request_items = merge_list(request_items, now_request_items)
        count += 1

    new_top = [category for category in top_category if "其它" not in category]
    if len(request_items) > FETSIZE:
        request_items = request_items[:FETSIZE]
    return request_items, new_top
|
454 |
+
|
455 |
+
|
456 |
+
|
457 |
+
def get_relevant_history(query,history,embeddings):
    """
    Retrieve a list of key history entries based on a query using semantic search.

    Args:
        query (str): The input query for which key history is to be retrieved.
        history (list): A list of historical key entries.
        embeddings: Embedding vectors of the historical entries; must expose
            ``shape[0]`` (the number of entries).

    Returns:
        list: The history entries most similar to the query (at most ``TOP_K``,
        read from the environment, default 2).  An empty list is returned when
        a search hit cannot be mapped back onto ``history``.
    """
    # NOTE(security): eval() on an environment variable executes arbitrary
    # code if the environment is not trusted; int() would be safer.
    TOP_K = eval(os.environ["TOP_K"]) if "TOP_K" in os.environ else 2
    relevant_history = []
    query_embedding = get_embedding(query)
    hits = semantic_search(query_embedding, embeddings, top_k=min(TOP_K,embeddings.shape[0]))
    # One query was submitted, so only the first result list is relevant.
    hits = hits[0]
    for hit in hits:
        matching_idx = hit["corpus_id"]
        try:
            relevant_history.append(history[matching_idx])
        except (IndexError, KeyError, TypeError):
            # Best effort: history and embeddings are out of sync, so give up
            # rather than return a partial result.  Only lookup failures are
            # caught; a bare ``except`` would also hide KeyboardInterrupt.
            return []
    return relevant_history
|
app.py
CHANGED
@@ -95,8 +95,6 @@ class NovelUI(WebUI):
|
|
95 |
|
96 |
def construct_ui(self):
|
97 |
with gr.Blocks(css=gc.CSS) as demo:
|
98 |
-
gr.Markdown("""# Agents""")
|
99 |
-
gr.Markdown("""**Agents** is an open-source library/framework for building autonomous language agents.if you want to know more about **Agents**, please check our<a href="https://arxiv.org/pdf/2309.07870.pdf">📄 Paper</a> and<a href="http://www.aiwaves-agents.com/">📦 Github</a>. Here is a demo of **Agents**. You can use it to write a novel.""")
|
100 |
with gr.Column():
|
101 |
self.progress = gr.HTML(
|
102 |
value=sc.FORMAT.format(
|
@@ -111,6 +109,11 @@ class NovelUI(WebUI):
|
|
111 |
label="Dialog",
|
112 |
height=500
|
113 |
)
|
|
|
|
|
|
|
|
|
|
|
114 |
with gr.Row():
|
115 |
self.text_requirement = gr.Textbox(
|
116 |
placeholder="Requirement of the novel",
|
@@ -145,7 +148,7 @@ class NovelUI(WebUI):
|
|
145 |
# ===============Event Listener===============
|
146 |
self.btn_start.click(
|
147 |
fn=self.btn_start_when_click,
|
148 |
-
inputs=[self.text_requirement],
|
149 |
outputs=[self.chatbot, self.chat_record, self.btn_start, self.text_requirement]
|
150 |
).then(
|
151 |
fn=self.btn_start_after_click,
|
@@ -169,7 +172,7 @@ class NovelUI(WebUI):
|
|
169 |
# ===========================================
|
170 |
self.demo = demo
|
171 |
|
172 |
-
def btn_start_when_click(self, text_requirement:str):
|
173 |
"""
|
174 |
inputs=[self.text_requirement],
|
175 |
outputs=[self.chatbot, self.chat_record, self.btn_start, self.text_requirement]
|
@@ -179,7 +182,7 @@ class NovelUI(WebUI):
|
|
179 |
gr.Chatbot.update(visible=True),\
|
180 |
gr.Button.update(interactive=False, value="Running"),\
|
181 |
gr.Textbox.update(value="", interactive=False)
|
182 |
-
self.send_start_cmd({'requirement': text_requirement})
|
183 |
return
|
184 |
|
185 |
def btn_start_after_click(self, history:List, record):
|
@@ -283,4 +286,4 @@ class NovelUI(WebUI):
|
|
283 |
if __name__ == '__main__':
|
284 |
ui = NovelUI(client_cmd=["python","gradio_backend.py"])
|
285 |
ui.construct_ui()
|
286 |
-
ui.run(
|
|
|
95 |
|
96 |
def construct_ui(self):
|
97 |
with gr.Blocks(css=gc.CSS) as demo:
|
|
|
|
|
98 |
with gr.Column():
|
99 |
self.progress = gr.HTML(
|
100 |
value=sc.FORMAT.format(
|
|
|
109 |
label="Dialog",
|
110 |
height=500
|
111 |
)
|
112 |
+
self.text_api = gr.Textbox(
|
113 |
+
value = self.cache["api_key"],
|
114 |
+
placeholder="openai key",
|
115 |
+
label="Please input valid openai key for gpt-3.5-turbo-16k."
|
116 |
+
)
|
117 |
with gr.Row():
|
118 |
self.text_requirement = gr.Textbox(
|
119 |
placeholder="Requirement of the novel",
|
|
|
148 |
# ===============Event Listener===============
|
149 |
self.btn_start.click(
|
150 |
fn=self.btn_start_when_click,
|
151 |
+
inputs=[self.text_requirement, self.text_api],
|
152 |
outputs=[self.chatbot, self.chat_record, self.btn_start, self.text_requirement]
|
153 |
).then(
|
154 |
fn=self.btn_start_after_click,
|
|
|
172 |
# ===========================================
|
173 |
self.demo = demo
|
174 |
|
175 |
+
def btn_start_when_click(self, text_requirement:str, api_key:str):
|
176 |
"""
|
177 |
inputs=[self.text_requirement],
|
178 |
outputs=[self.chatbot, self.chat_record, self.btn_start, self.text_requirement]
|
|
|
182 |
gr.Chatbot.update(visible=True),\
|
183 |
gr.Button.update(interactive=False, value="Running"),\
|
184 |
gr.Textbox.update(value="", interactive=False)
|
185 |
+
self.send_start_cmd({'requirement': text_requirement, "api_key": api_key})
|
186 |
return
|
187 |
|
188 |
def btn_start_after_click(self, history:List, record):
|
|
|
286 |
if __name__ == '__main__':
|
287 |
ui = NovelUI(client_cmd=["python","gradio_backend.py"])
|
288 |
ui.construct_ui()
|
289 |
+
ui.run()
|
create_sop.py
CHANGED
@@ -37,8 +37,7 @@ def create_sop(folder_name: str = "novel_outline", encoding: str = "utf-8", save
|
|
37 |
sop_file = f"./{save_name}.json"
|
38 |
sop_dict = {
|
39 |
"config": {
|
40 |
-
"API_KEY": "sk-
|
41 |
-
"PROXY": "",
|
42 |
"MAX_CHAT_HISTORY" : "100",
|
43 |
"TOP_K" : "1",
|
44 |
"ACTIVE_MODE" : "0",
|
|
|
37 |
sop_file = f"./{save_name}.json"
|
38 |
sop_dict = {
|
39 |
"config": {
|
40 |
+
"API_KEY": "sk-xxxxxxxxxxxxxxxxxxxx",
|
|
|
41 |
"MAX_CHAT_HISTORY" : "100",
|
42 |
"TOP_K" : "1",
|
43 |
"ACTIVE_MODE" : "0",
|
gradio_backend.py
CHANGED
@@ -1,16 +1,16 @@
|
|
1 |
import sys
|
2 |
sys.path.append("./novel-server")
|
3 |
-
|
4 |
import yaml
|
5 |
import os
|
6 |
import argparse
|
7 |
import random
|
8 |
|
9 |
-
from
|
10 |
-
from
|
11 |
-
from
|
12 |
from gradio_base import Client
|
13 |
-
from
|
14 |
|
15 |
from myagent import Node, MyAgent, ask_gpt
|
16 |
from typing import List, Tuple
|
@@ -62,10 +62,13 @@ if __name__ == "__main__":
|
|
62 |
"agents_name": ['Elmo','Abby', 'Zoe', 'Ernie', 'Bert', 'Oscar'],
|
63 |
"nodes_name": ['Node 1','Node 2','Node 3', 'Node 4', 'state1', 'state2', 'state3', 'state4'],
|
64 |
"output_file_path": f"{os.getcwd()+'/novel_outline'}",
|
65 |
-
"requirement": NOVEL_PROMPT['Node 1']["task"]
|
|
|
66 |
}
|
67 |
)
|
68 |
client.listening_for_start_()
|
|
|
|
|
69 |
NOVEL_PROMPT['Node 1']['task'] = client.cache['requirement']
|
70 |
print("Received: ", client.cache['requirement'])
|
71 |
outline = run_node_1(
|
@@ -92,5 +95,4 @@ if __name__ == "__main__":
|
|
92 |
show_in_gradio(30, str(name_list), " ", " ")
|
93 |
|
94 |
agents,sop,environment = init("novel_outline.json")
|
95 |
-
run(agents,sop,environment)
|
96 |
-
|
|
|
1 |
import sys
|
2 |
sys.path.append("./novel-server")
|
3 |
+
sys.path.append("agents")
|
4 |
import yaml
|
5 |
import os
|
6 |
import argparse
|
7 |
import random
|
8 |
|
9 |
+
from SOP import SOP
|
10 |
+
from Agent import Agent
|
11 |
+
from Environment import Environment
|
12 |
from gradio_base import Client
|
13 |
+
from Memory import Memory
|
14 |
|
15 |
from myagent import Node, MyAgent, ask_gpt
|
16 |
from typing import List, Tuple
|
|
|
62 |
"agents_name": ['Elmo','Abby', 'Zoe', 'Ernie', 'Bert', 'Oscar'],
|
63 |
"nodes_name": ['Node 1','Node 2','Node 3', 'Node 4', 'state1', 'state2', 'state3', 'state4'],
|
64 |
"output_file_path": f"{os.getcwd()+'/novel_outline'}",
|
65 |
+
"requirement": NOVEL_PROMPT['Node 1']["task"],
|
66 |
+
"api_key": "sk-xxxxxxxxxxxxxxxxxxxx"
|
67 |
}
|
68 |
)
|
69 |
client.listening_for_start_()
|
70 |
+
os.environ["API_KEY"] = client.cache["api_key"]
|
71 |
+
MyAgent.API_KEY = client.cache["api_key"]
|
72 |
NOVEL_PROMPT['Node 1']['task'] = client.cache['requirement']
|
73 |
print("Received: ", client.cache['requirement'])
|
74 |
outline = run_node_1(
|
|
|
95 |
show_in_gradio(30, str(name_list), " ", " ")
|
96 |
|
97 |
agents,sop,environment = init("novel_outline.json")
|
98 |
+
run(agents,sop,environment)
|
|