fak111 commited on
Commit
aea420f
·
1 Parent(s): 3840495
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import gradio as gr
4
+ import requests
5
+ from lagent.schema import AgentStatusCode
6
+ import os
7
+ os.system("python -m mindsearch.app --lang cn --model_format internlm_silicon &")
8
+ PLANNER_HISTORY = []
9
+ SEARCHER_HISTORY = []
10
+
11
+ def rst_mem(history_planner: list, history_searcher: list):
12
+ '''
13
+ Reset the chatbot memory.
14
+ '''
15
+ history_planner = []
16
+ history_searcher = []
17
+ if PLANNER_HISTORY:
18
+ PLANNER_HISTORY.clear()
19
+ return history_planner, history_searcher
20
+
21
+ def format_response(gr_history, agent_return):
22
+ if agent_return['state'] in [
23
+ AgentStatusCode.STREAM_ING, AgentStatusCode.ANSWER_ING
24
+ ]:
25
+ gr_history[-1][1] = agent_return['response']
26
+ elif agent_return['state'] == AgentStatusCode.PLUGIN_START:
27
+ thought = gr_history[-1][1].split('```')[0]
28
+ if agent_return['response'].startswith('```'):
29
+ gr_history[-1][1] = thought + '\n' + agent_return['response']
30
+ elif agent_return['state'] == AgentStatusCode.PLUGIN_END:
31
+ thought = gr_history[-1][1].split('```')[0]
32
+ if isinstance(agent_return['response'], dict):
33
+ gr_history[-1][
34
+ 1] = thought + '\n' + f'```json\n{json.dumps(agent_return["response"], ensure_ascii=False, indent=4)}\n```' # noqa: E501
35
+ elif agent_return['state'] == AgentStatusCode.PLUGIN_RETURN:
36
+ assert agent_return['inner_steps'][-1]['role'] == 'environment'
37
+ item = agent_return['inner_steps'][-1]
38
+ gr_history.append([
39
+ None,
40
+ f"```json\n{json.dumps(item['content'], ensure_ascii=False, indent=4)}\n```"
41
+ ])
42
+ gr_history.append([None, ''])
43
+ return
44
+
45
+ def predict(history_planner, history_searcher):
46
+
47
+ def streaming(raw_response):
48
+ for chunk in raw_response.iter_lines(chunk_size=8192,
49
+ decode_unicode=False,
50
+ delimiter=b'\n'):
51
+ if chunk:
52
+ decoded = chunk.decode('utf-8')
53
+ if decoded == '\r':
54
+ continue
55
+ if decoded[:6] == 'data: ':
56
+ decoded = decoded[6:]
57
+ elif decoded.startswith(': ping - '):
58
+ continue
59
+ response = json.loads(decoded)
60
+ yield (response['response'], response['current_node'])
61
+
62
+ global PLANNER_HISTORY
63
+ PLANNER_HISTORY.append(dict(role='user', content=history_planner[-1][0]))
64
+ new_search_turn = True
65
+
66
+ url = 'http://localhost:8002/solve'
67
+ headers = {'Content-Type': 'application/json'}
68
+ data = {'inputs': PLANNER_HISTORY}
69
+ raw_response = requests.post(url,
70
+ headers=headers,
71
+ data=json.dumps(data),
72
+ timeout=20,
73
+ stream=True)
74
+
75
+ for resp in streaming(raw_response):
76
+ agent_return, node_name = resp
77
+ if node_name:
78
+ if node_name in ['root', 'response']:
79
+ continue
80
+ agent_return = agent_return['nodes'][node_name]['detail']
81
+ if new_search_turn:
82
+ history_searcher.append([agent_return['content'], ''])
83
+ new_search_turn = False
84
+ format_response(history_searcher, agent_return)
85
+ if agent_return['state'] == AgentStatusCode.END:
86
+ new_search_turn = True
87
+ yield history_planner, history_searcher
88
+ else:
89
+ new_search_turn = True
90
+ format_response(history_planner, agent_return)
91
+ if agent_return['state'] == AgentStatusCode.END:
92
+ PLANNER_HISTORY = agent_return['inner_steps']
93
+ yield history_planner, history_searcher
94
+ return history_planner, history_searcher
95
+
96
+ examples = [
97
+ ["Find legal precedents in contract law."],
98
+ ["What are the top 10 e-commerce websites?"],
99
+ ["Generate a report on global climate change."],
100
+ ]
101
+ import os
102
+ css_path = os.path.join(os.path.dirname(__file__), "css", "test1.css")
103
+ with gr.Blocks(css=css_path) as demo:
104
+ with gr.Column(elem_classes="chat-box"):
105
+ gr.HTML("""<h1 align="center">MindSearch Gradio Demo</h1>""")
106
+ gr.HTML("""<p style="text-align: center; font-family: Arial, sans-serif;">
107
+ MindSearch is an open-source AI Search Engine Framework with Perplexity.ai Pro performance. You can deploy your own Perplexity.ai-style search engine using either closed-source LLMs (GPT, Claude)
108
+ or open-source LLMs (InternLM2.5-7b-chat).</p> """)
109
+ gr.HTML("""
110
+ <div style="text-align: center; font-size: 16px;">
111
+ <a href="https://github.com/InternLM/MindSearch" style="margin-right: 15px; text-decoration: none; color: #4A90E2;" target="_blank">🔗 GitHub</a>
112
+ <a href="https://arxiv.org/abs/2407.20183" style="margin-right: 15px; text-decoration: none; color: #4A90E2;" target="_blank">📄 Arxiv</a>
113
+ <a href="https://huggingface.co/papers/2407.20183" style="margin-right: 15px; text-decoration: none; color: #4A90E2;" target="_blank">📚 Hugging Face Papers</a>
114
+ <a href="https://huggingface.co/spaces/internlm/MindSearch" style="text-decoration: none; color: #4A90E2;" target="_blank">🤗 Hugging Face Demo</a>
115
+ </div>""")
116
+ gr.HTML("""
117
+ <h1 align='right'><img src='https://raw.githubusercontent.com/InternLM/MindSearch/98fd84d566fe9e3adc5028727f72f2944098fd05/assets/logo.svg' alt='MindSearch Logo1' class="logo"></h1>
118
+ """)
119
+
120
+ with gr.Row():
121
+ with gr.Column(scale=10):
122
+ with gr.Row():
123
+ with gr.Column():
124
+ planner = gr.Chatbot(label='planner',
125
+ show_label=True,
126
+ show_copy_button=True,
127
+ bubble_full_width=False,
128
+ render_markdown=True,
129
+ elem_classes="chatbot-container")
130
+ with gr.Column():
131
+ searcher = gr.Chatbot(label='searcher',
132
+ show_label=True,
133
+ show_copy_button=True,
134
+ bubble_full_width=False,
135
+ render_markdown=True,
136
+ elem_classes="chatbot-container")
137
+
138
+ with gr.Row(elem_classes="chat-box"):
139
+ # Text input area
140
+ user_input = gr.Textbox(
141
+ show_label=False,
142
+ placeholder="Type your message...",
143
+ lines=1,
144
+ container=False,
145
+ elem_classes="editor"
146
+ )
147
+ # Buttons (now in the same Row)
148
+ submitBtn = gr.Button("submit", variant="primary", elem_classes="toolbarButton")
149
+ clearBtn = gr.Button("clear", variant="secondary", elem_classes="toolbarButton")
150
+ with gr.Row(elem_classes="examples-container"):
151
+ examples_component = gr.Examples(examples, inputs=user_input,
152
+ label="Try these examples:")
153
+
154
+ def user(query, history):
155
+ return '', history + [[query, '']]
156
+
157
+ def submit_example(example):
158
+ return user(example[0], planner.value)
159
+
160
+ submitBtn.click(user, [user_input, planner], [user_input, planner],
161
+ queue=False).then(predict, [planner, searcher],
162
+ [planner, searcher])
163
+ clearBtn.click(rst_mem, [planner, searcher], [planner, searcher],
164
+ queue=False)
165
+
166
+ demo.queue()
167
+ demo.launch(server_name='127.0.0.1',
168
+ server_port=7884,
169
+ inbrowser=True,
170
+ share=True)
mindsearch/__pycache__/app.cpython-310.pyc ADDED
Binary file (4.26 kB). View file
 
mindsearch/agent/__init__.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datetime import datetime
3
+
4
+ from lagent.actions import ActionExecutor, BingBrowser
5
+
6
+ import mindsearch.agent.models as llm_factory
7
+ from mindsearch.agent.mindsearch_agent import (MindSearchAgent,
8
+ MindSearchProtocol)
9
+ from mindsearch.agent.mindsearch_prompt import (
10
+ FINAL_RESPONSE_CN, FINAL_RESPONSE_EN, GRAPH_PROMPT_CN, GRAPH_PROMPT_EN,
11
+ fewshot_example_cn, fewshot_example_en, graph_fewshot_example_cn,
12
+ graph_fewshot_example_en, searcher_context_template_cn,
13
+ searcher_context_template_en, searcher_input_template_cn,
14
+ searcher_input_template_en, searcher_system_prompt_cn,
15
+ searcher_system_prompt_en)
16
+
17
+ LLM = {}
18
+
19
+
20
+ def init_agent(lang='cn', model_format='internlm_server',search_engine='DuckDuckGoSearch'):
21
+ llm = LLM.get(model_format, None)
22
+ if llm is None:
23
+ llm_cfg = getattr(llm_factory, model_format)
24
+ if llm_cfg is None:
25
+ raise NotImplementedError
26
+ llm_cfg = llm_cfg.copy()
27
+ llm = llm_cfg.pop('type')(**llm_cfg)
28
+ LLM[model_format] = llm
29
+
30
+ interpreter_prompt = GRAPH_PROMPT_CN if lang == 'cn' else GRAPH_PROMPT_EN
31
+ plugin_prompt = searcher_system_prompt_cn if lang == 'cn' else searcher_system_prompt_en
32
+ if not model_format.lower().startswith('internlm'):
33
+ interpreter_prompt += graph_fewshot_example_cn if lang == 'cn' else graph_fewshot_example_en
34
+ plugin_prompt += fewshot_example_cn if lang == 'cn' else fewshot_example_en
35
+
36
+ agent = MindSearchAgent(
37
+ llm=llm,
38
+ protocol=MindSearchProtocol(meta_prompt=datetime.now().strftime(
39
+ 'The current date is %Y-%m-%d.'),
40
+ interpreter_prompt=interpreter_prompt,
41
+ response_prompt=FINAL_RESPONSE_CN
42
+ if lang == 'cn' else FINAL_RESPONSE_EN),
43
+ searcher_cfg=dict(
44
+ llm=llm,
45
+ plugin_executor=ActionExecutor(
46
+ BingBrowser(searcher_type=search_engine,
47
+ topk=6,
48
+ api_key=os.environ.get('BING_API_KEY',
49
+ 'YOUR BING API'))),
50
+ protocol=MindSearchProtocol(
51
+ meta_prompt=datetime.now().strftime(
52
+ 'The current date is %Y-%m-%d.'),
53
+ plugin_prompt=plugin_prompt,
54
+ ),
55
+ template=dict(input=searcher_input_template_cn
56
+ if lang == 'cn' else searcher_input_template_en,
57
+ context=searcher_context_template_cn
58
+ if lang == 'cn' else searcher_context_template_en)),
59
+ max_turn=10)
60
+ return agent
mindsearch/agent/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.09 kB). View file
 
mindsearch/agent/__pycache__/mindsearch_agent.cpython-310.pyc ADDED
Binary file (13.1 kB). View file
 
mindsearch/agent/__pycache__/mindsearch_prompt.cpython-310.pyc ADDED
Binary file (14.2 kB). View file
 
mindsearch/agent/__pycache__/models.cpython-310.pyc ADDED
Binary file (1.69 kB). View file
 
mindsearch/agent/mindsearch_agent.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import queue
4
+ import random
5
+ import re
6
+ import threading
7
+ import uuid
8
+ from collections import defaultdict
9
+ from concurrent.futures import ThreadPoolExecutor, as_completed
10
+ from copy import deepcopy
11
+ from dataclasses import asdict
12
+ from typing import Dict, List, Optional
13
+
14
+ from lagent.actions import ActionExecutor
15
+ from lagent.agents import BaseAgent, Internlm2Agent
16
+ from lagent.agents.internlm2_agent import Internlm2Protocol
17
+ from lagent.schema import AgentReturn, AgentStatusCode, ModelStatusCode
18
+ from termcolor import colored
19
+
20
+ # 初始化日志记录
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class SearcherAgent(Internlm2Agent):
26
+
27
+ def __init__(self, template='{query}', **kwargs) -> None:
28
+ super().__init__(**kwargs)
29
+ self.template = template
30
+
31
+ def stream_chat(self,
32
+ question: str,
33
+ root_question: str = None,
34
+ parent_response: List[dict] = None,
35
+ **kwargs) -> AgentReturn:
36
+ message = self.template['input'].format(question=question,
37
+ topic=root_question)
38
+ if parent_response:
39
+ if 'context' in self.template:
40
+ parent_response = [
41
+ self.template['context'].format(**item)
42
+ for item in parent_response
43
+ ]
44
+ message = '\n'.join(parent_response + [message])
45
+ print(colored(f'current query: {message}', 'green'))
46
+ for agent_return in super().stream_chat(message,
47
+ session_id=random.randint(
48
+ 0, 999999),
49
+ **kwargs):
50
+ agent_return.type = 'searcher'
51
+ agent_return.content = question
52
+ yield deepcopy(agent_return)
53
+
54
+
55
+ class MindSearchProtocol(Internlm2Protocol):
56
+
57
+ def __init__(
58
+ self,
59
+ meta_prompt: str = None,
60
+ interpreter_prompt: str = None,
61
+ plugin_prompt: str = None,
62
+ few_shot: Optional[List] = None,
63
+ response_prompt: str = None,
64
+ language: Dict = dict(
65
+ begin='',
66
+ end='',
67
+ belong='assistant',
68
+ ),
69
+ tool: Dict = dict(
70
+ begin='{start_token}{name}\n',
71
+ start_token='<|action_start|>',
72
+ name_map=dict(plugin='<|plugin|>', interpreter='<|interpreter|>'),
73
+ belong='assistant',
74
+ end='<|action_end|>\n',
75
+ ),
76
+ execute: Dict = dict(role='execute',
77
+ begin='',
78
+ end='',
79
+ fallback_role='environment'),
80
+ ) -> None:
81
+ self.response_prompt = response_prompt
82
+ super().__init__(meta_prompt=meta_prompt,
83
+ interpreter_prompt=interpreter_prompt,
84
+ plugin_prompt=plugin_prompt,
85
+ few_shot=few_shot,
86
+ language=language,
87
+ tool=tool,
88
+ execute=execute)
89
+
90
+ def format(self,
91
+ inner_step: List[Dict],
92
+ plugin_executor: ActionExecutor = None,
93
+ **kwargs) -> list:
94
+ formatted = []
95
+ if self.meta_prompt:
96
+ formatted.append(dict(role='system', content=self.meta_prompt))
97
+ if self.plugin_prompt:
98
+ plugin_prompt = self.plugin_prompt.format(tool_info=json.dumps(
99
+ plugin_executor.get_actions_info(), ensure_ascii=False))
100
+ formatted.append(
101
+ dict(role='system', content=plugin_prompt, name='plugin'))
102
+ if self.interpreter_prompt:
103
+ formatted.append(
104
+ dict(role='system',
105
+ content=self.interpreter_prompt,
106
+ name='interpreter'))
107
+ if self.few_shot:
108
+ for few_shot in self.few_shot:
109
+ formatted += self.format_sub_role(few_shot)
110
+ formatted += self.format_sub_role(inner_step)
111
+ return formatted
112
+
113
+
114
+ class WebSearchGraph:
115
+ end_signal = 'end'
116
+ searcher_cfg = dict()
117
+
118
+ def __init__(self):
119
+ self.nodes = {}
120
+ self.adjacency_list = defaultdict(list)
121
+ self.executor = ThreadPoolExecutor(max_workers=10)
122
+ self.future_to_query = dict()
123
+ self.searcher_resp_queue = queue.Queue()
124
+
125
+ def add_root_node(self, node_content, node_name='root'):
126
+ self.nodes[node_name] = dict(content=node_content, type='root')
127
+ self.adjacency_list[node_name] = []
128
+ self.searcher_resp_queue.put((node_name, self.nodes[node_name], []))
129
+
130
+ def add_node(self, node_name, node_content):
131
+ self.nodes[node_name] = dict(content=node_content, type='searcher')
132
+ self.adjacency_list[node_name] = []
133
+
134
+ def model_stream_thread():
135
+ agent = SearcherAgent(**self.searcher_cfg)
136
+ try:
137
+ parent_nodes = []
138
+ for start_node, adj in self.adjacency_list.items():
139
+ for neighbor in adj:
140
+ if node_name == neighbor[
141
+ 'name'] and start_node in self.nodes and 'response' in self.nodes[
142
+ start_node]:
143
+ parent_nodes.append(self.nodes[start_node])
144
+ parent_response = [
145
+ dict(question=node['content'], answer=node['response'])
146
+ for node in parent_nodes
147
+ ]
148
+ for answer in agent.stream_chat(
149
+ node_content,
150
+ self.nodes['root']['content'],
151
+ parent_response=parent_response):
152
+ self.searcher_resp_queue.put(
153
+ deepcopy((node_name,
154
+ dict(response=answer.response,
155
+ detail=answer), [])))
156
+ self.nodes[node_name]['response'] = answer.response
157
+ self.nodes[node_name]['detail'] = answer
158
+ except Exception as e:
159
+ logger.exception(f'Error in model_stream_thread: {e}')
160
+
161
+ self.future_to_query[self.executor.submit(
162
+ model_stream_thread)] = f'{node_name}-{node_content}'
163
+
164
+ def add_response_node(self, node_name='response'):
165
+ self.nodes[node_name] = dict(type='end')
166
+ self.searcher_resp_queue.put((node_name, self.nodes[node_name], []))
167
+
168
+ def add_edge(self, start_node, end_node):
169
+ self.adjacency_list[start_node].append(
170
+ dict(id=str(uuid.uuid4()), name=end_node, state=2))
171
+ self.searcher_resp_queue.put((start_node, self.nodes[start_node],
172
+ self.adjacency_list[start_node]))
173
+
174
+ def reset(self):
175
+ self.nodes = {}
176
+ self.adjacency_list = defaultdict(list)
177
+
178
+ def node(self, node_name):
179
+ return self.nodes[node_name].copy()
180
+
181
+
182
+ class MindSearchAgent(BaseAgent):
183
+
184
+ def __init__(self,
185
+ llm,
186
+ searcher_cfg,
187
+ protocol=MindSearchProtocol(),
188
+ max_turn=10):
189
+ self.local_dict = {}
190
+ self.ptr = 0
191
+ self.llm = llm
192
+ self.max_turn = max_turn
193
+ WebSearchGraph.searcher_cfg = searcher_cfg
194
+ super().__init__(llm=llm, action_executor=None, protocol=protocol)
195
+
196
+ def stream_chat(self, message, **kwargs):
197
+ if isinstance(message, str):
198
+ message = [{'role': 'user', 'content': message}]
199
+ elif isinstance(message, dict):
200
+ message = [message]
201
+ as_dict = kwargs.pop('as_dict', False)
202
+ return_early = kwargs.pop('return_early', False)
203
+ self.local_dict.clear()
204
+ self.ptr = 0
205
+ inner_history = message[:]
206
+ agent_return = AgentReturn()
207
+ agent_return.type = 'planner'
208
+ agent_return.nodes = {}
209
+ agent_return.adjacency_list = {}
210
+ agent_return.inner_steps = deepcopy(inner_history)
211
+ for _ in range(self.max_turn):
212
+ prompt = self._protocol.format(inner_step=inner_history)
213
+ code = None
214
+ for model_state, response, _ in self.llm.stream_chat(
215
+ prompt, session_id=random.randint(0, 999999), **kwargs):
216
+ if model_state.value < 0:
217
+ agent_return.state = getattr(AgentStatusCode,
218
+ model_state.name)
219
+ yield deepcopy(agent_return)
220
+ return
221
+ response = response.replace('<|plugin|>', '<|interpreter|>')
222
+ _, language, action = self._protocol.parse(response)
223
+ if not language and not action:
224
+ continue
225
+ code = action['parameters']['command'] if action else ''
226
+ agent_return.state = self._determine_agent_state(
227
+ model_state, code, agent_return)
228
+ agent_return.response = language if not code else code
229
+
230
+ # if agent_return.state == AgentStatusCode.STREAM_ING:
231
+ yield deepcopy(agent_return)
232
+
233
+ inner_history.append({'role': 'language', 'content': language})
234
+ print(colored(response, 'blue'))
235
+
236
+ if code:
237
+ yield from self._process_code(agent_return, inner_history,
238
+ code, as_dict, return_early)
239
+ else:
240
+ agent_return.state = AgentStatusCode.END
241
+ yield deepcopy(agent_return)
242
+ return
243
+
244
+ agent_return.state = AgentStatusCode.END
245
+ yield deepcopy(agent_return)
246
+
247
+ def _determine_agent_state(self, model_state, code, agent_return):
248
+ if code:
249
+ return (AgentStatusCode.PLUGIN_START if model_state
250
+ == ModelStatusCode.END else AgentStatusCode.PLUGIN_START)
251
+ return (AgentStatusCode.ANSWER_ING
252
+ if agent_return.nodes and 'response' in agent_return.nodes else
253
+ AgentStatusCode.STREAM_ING)
254
+
255
+ def _process_code(self,
256
+ agent_return,
257
+ inner_history,
258
+ code,
259
+ as_dict=False,
260
+ return_early=False):
261
+ for node_name, node, adj in self.execute_code(
262
+ code, return_early=return_early):
263
+ if as_dict and 'detail' in node:
264
+ node['detail'] = asdict(node['detail'])
265
+ if not adj:
266
+ agent_return.nodes[node_name] = node
267
+ else:
268
+ agent_return.adjacency_list[node_name] = adj
269
+ # state 1进行中,2未开始,3已结束
270
+ for start_node, neighbors in agent_return.adjacency_list.items():
271
+ for neighbor in neighbors:
272
+ if neighbor['name'] not in agent_return.nodes:
273
+ state = 2
274
+ elif 'detail' not in agent_return.nodes[neighbor['name']]:
275
+ state = 2
276
+ elif agent_return.nodes[neighbor['name']][
277
+ 'detail'].state == AgentStatusCode.END:
278
+ state = 3
279
+ else:
280
+ state = 1
281
+ neighbor['state'] = state
282
+ if not adj:
283
+ yield deepcopy((agent_return, node_name))
284
+ reference, references_url = self._generate_reference(
285
+ agent_return, code, as_dict)
286
+ inner_history.append({
287
+ 'role': 'tool',
288
+ 'content': code,
289
+ 'name': 'plugin'
290
+ })
291
+ inner_history.append({
292
+ 'role': 'environment',
293
+ 'content': reference,
294
+ 'name': 'plugin'
295
+ })
296
+ agent_return.inner_steps = deepcopy(inner_history)
297
+ agent_return.state = AgentStatusCode.PLUGIN_RETURN
298
+ agent_return.references.update(references_url)
299
+ yield deepcopy(agent_return)
300
+
301
+ def _generate_reference(self, agent_return, code, as_dict):
302
+ node_list = [
303
+ node.strip().strip('\"') for node in re.findall(
304
+ r'graph\.node\("((?:[^"\\]|\\.)*?)"\)', code)
305
+ ]
306
+ if 'add_response_node' in code:
307
+ return self._protocol.response_prompt, dict()
308
+ references = []
309
+ references_url = dict()
310
+ for node_name in node_list:
311
+ ref_results = None
312
+ ref2url = None
313
+ if as_dict:
314
+ actions = agent_return.nodes[node_name]['detail']['actions']
315
+ else:
316
+ actions = agent_return.nodes[node_name]['detail'].actions
317
+ if actions:
318
+ ref_results = actions[0]['result'][0][
319
+ 'content'] if as_dict else actions[0].result[0]['content']
320
+ if ref_results:
321
+ ref_results = json.loads(ref_results)
322
+ ref2url = {
323
+ idx: item['url']
324
+ for idx, item in ref_results.items()
325
+ }
326
+
327
+ ref = f"## {node_name}\n\n{agent_return.nodes[node_name]['response']}\n"
328
+ updated_ref = re.sub(
329
+ r'\[\[(\d+)\]\]',
330
+ lambda match: f'[[{int(match.group(1)) + self.ptr}]]', ref)
331
+ numbers = [int(n) for n in re.findall(r'\[\[(\d+)\]\]', ref)]
332
+ if numbers:
333
+ try:
334
+ assert all(str(elem) in ref2url for elem in numbers)
335
+ except Exception as exc:
336
+ logger.info(f'Illegal reference id: {str(exc)}')
337
+ if ref2url:
338
+ references_url.update({
339
+ str(idx + self.ptr): ref2url[str(idx)]
340
+ for idx in set(numbers) if str(idx) in ref2url
341
+ })
342
+ self.ptr += max(numbers) + 1
343
+ references.append(updated_ref)
344
+ return '\n'.join(references), references_url
345
+
346
+ def execute_code(self, command: str, return_early=False):
347
+
348
+ def extract_code(text: str) -> str:
349
+ text = re.sub(r'from ([\w.]+) import WebSearchGraph', '', text)
350
+ triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
351
+ single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
352
+ if triple_match:
353
+ return triple_match.group(1)
354
+ elif single_match:
355
+ return single_match.group(1)
356
+ return text
357
+
358
+ def run_command(cmd):
359
+ try:
360
+ exec(cmd, globals(), self.local_dict)
361
+ plan_graph = self.local_dict.get('graph')
362
+ assert plan_graph is not None
363
+ for future in as_completed(plan_graph.future_to_query):
364
+ future.result()
365
+ plan_graph.future_to_query.clear()
366
+ plan_graph.searcher_resp_queue.put(plan_graph.end_signal)
367
+ except Exception as e:
368
+ logger.exception(f'Error executing code: {e}')
369
+ raise
370
+
371
+ command = extract_code(command)
372
+ producer_thread = threading.Thread(target=run_command,
373
+ args=(command, ))
374
+ producer_thread.start()
375
+
376
+ responses = defaultdict(list)
377
+ ordered_nodes = []
378
+ active_node = None
379
+
380
+ while True:
381
+ try:
382
+ item = self.local_dict.get('graph').searcher_resp_queue.get(
383
+ timeout=60)
384
+ if item is WebSearchGraph.end_signal:
385
+ for node_name in ordered_nodes:
386
+ # resp = None
387
+ for resp in responses[node_name]:
388
+ yield deepcopy(resp)
389
+ # if resp:
390
+ # assert resp[1][
391
+ # 'detail'].state == AgentStatusCode.END
392
+ break
393
+ node_name, node, adj = item
394
+ if node_name in ['root', 'response']:
395
+ yield deepcopy((node_name, node, adj))
396
+ else:
397
+ if node_name not in ordered_nodes:
398
+ ordered_nodes.append(node_name)
399
+ responses[node_name].append((node_name, node, adj))
400
+ if not active_node and ordered_nodes:
401
+ active_node = ordered_nodes[0]
402
+ while active_node and responses[active_node]:
403
+ if return_early:
404
+ if 'detail' in responses[active_node][-1][
405
+ 1] and responses[active_node][-1][1][
406
+ 'detail'].state == AgentStatusCode.END:
407
+ item = responses[active_node][-1]
408
+ else:
409
+ item = responses[active_node].pop(0)
410
+ else:
411
+ item = responses[active_node].pop(0)
412
+ if 'detail' in item[1] and item[1][
413
+ 'detail'].state == AgentStatusCode.END:
414
+ ordered_nodes.pop(0)
415
+ responses[active_node].clear()
416
+ active_node = None
417
+ yield deepcopy(item)
418
+ except queue.Empty:
419
+ if not producer_thread.is_alive():
420
+ break
421
+ producer_thread.join()
422
+ return
mindsearch/agent/mindsearch_prompt.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ searcher_system_prompt_cn = """## 人物简介
4
+ 你是一个可以调用网络搜索工具的智能助手。请根据"当前问题",调用搜索工具收集信息并回复问题。你能够调用如下工具:
5
+ {tool_info}
6
+ ## 回复格式
7
+
8
+ 调用工具时,请按照以下格式:
9
+ ```
10
+ 你的思考过程...<|action_start|><|plugin|>{{"name": "tool_name", "parameters": {{"param1": "value1"}}}}<|action_end|>
11
+ ```
12
+
13
+ ## 要求
14
+
15
+ - 回答中每个关键点需标注引用的搜索结果来源,以确保信息的可信度。给出索引的形式为`[[int]]`,如果有多个索引,则用多个[[]]表示,如`[[id_1]][[id_2]]`。
16
+ - 基于"当前问题"的搜索结果,撰写详细完备的回复,优先回答"当前问题"。
17
+
18
+ """
19
+
20
+ searcher_system_prompt_en = """## Character Introduction
21
+ You are an intelligent assistant that can call web search tools. Please collect information and reply to the question based on the current problem. You can use the following tools:
22
+ {tool_info}
23
+ ## Reply Format
24
+
25
+ When calling the tool, please follow the format below:
26
+ ```
27
+ Your thought process...<|action_start|><|plugin|>{{"name": "tool_name", "parameters": {{"param1": "value1"}}}}<|action_end|>
28
+ ```
29
+
30
+ ## Requirements
31
+
32
+ - Each key point in the response should be marked with the source of the search results to ensure the credibility of the information. The citation format is `[[int]]`. If there are multiple citations, use multiple [[]] to provide the index, such as `[[id_1]][[id_2]]`.
33
+ - Based on the search results of the "current problem", write a detailed and complete reply to answer the "current problem".
34
+ """
35
+
36
+ fewshot_example_cn = """
37
+ ## 样例
38
+
39
+ ### search
40
+ 当我希望搜索"王者荣耀现在是什么赛季"时,我会按照以下格式进行操作:
41
+ 现在是2024年,因此我应该搜索王者荣耀赛季关键词<|action_start|><|plugin|>{{"name": "FastWebBrowser.search", "parameters": {{"query": ["王者荣耀 赛季", "2024年王者荣耀赛季"]}}}}<|action_end|>
42
+
43
+ ### select
44
+ 为了找到王者荣耀s36赛季最强射手,我需要寻找提及王者荣耀s36射手的网页。初步浏览网页后,发现网页0提到王者荣耀s36赛季的信息,但没有具体提及射手的相关信息。网页3提到“s36最强射手出现?”,有可能包含最强射手信息。网页13提到“四大T0英雄崛起,射手荣耀降临”,可能包含最强射手的信息。因此,我选择了网页3和网页13进行进一步阅读。<|action_start|><|plugin|>{{"name": "FastWebBrowser.select", "parameters": {{"index": [3, 13]}}}}<|action_end|>
45
+ """
46
+
47
+ fewshot_example_en = """
48
+ ## Example
49
+
50
+ ### search
51
+ When I want to search for "What season is Honor of Kings now", I will operate in the following format:
52
+ Now it is 2024, so I should search for the keyword of the Honor of Kings<|action_start|><|plugin|>{{"name": "FastWebBrowser.search", "parameters": {{"query": ["Honor of Kings Season", "season for Honor of Kings in 2024"]}}}}<|action_end|>
53
+
54
+ ### select
55
+ In order to find the strongest shooters in Honor of Kings in season s36, I needed to look for web pages that mentioned shooters in Honor of Kings in season s36. After an initial browse of the web pages, I found that web page 0 mentions information about Honor of Kings in s36 season, but there is no specific mention of information about the shooter. Webpage 3 mentions that “the strongest shooter in s36 has appeared?”, which may contain information about the strongest shooter. Webpage 13 mentions “Four T0 heroes rise, archer's glory”, which may contain information about the strongest archer. Therefore, I chose webpages 3 and 13 for further reading.<|action_start|><|plugin|>{{"name": "FastWebBrowser.select", "parameters": {{"index": [3, 13]}}}}<|action_end|>
56
+ """
57
+
58
+ searcher_input_template_en = """## Final Problem
59
+ {topic}
60
+ ## Current Problem
61
+ {question}
62
+ """
63
+
64
+ searcher_input_template_cn = """## 主问题
65
+ {topic}
66
+ ## 当前问题
67
+ {question}
68
+ """
69
+
70
+ searcher_context_template_en = """## Historical Problem
71
+ {question}
72
+ Answer: {answer}
73
+ """
74
+
75
+ searcher_context_template_cn = """## 历史问题
76
+ {question}
77
+ 回答:{answer}
78
+ """
79
+
80
+ search_template_cn = '## {query}\n\n{result}\n'
81
+ search_template_en = '## {query}\n\n{result}\n'
82
+
83
+ GRAPH_PROMPT_CN = """## 人物简介
84
+ 你是一个可以利用 Jupyter 环境 Python 编程的程序员。你可以利用提供的 API 来构建 Web 搜索图,最终生成代码并执行。
85
+
86
+ ## API 介绍
87
+
88
+ 下面是包含属性详细说明的 `WebSearchGraph` 类的 API 文档:
89
+
90
+ ### 类:`WebSearchGraph`
91
+
92
+ 此类用于管理网络搜索图的节点和边,并通过网络代理进行搜索。
93
+
94
+ #### 初始化方法
95
+
96
+ 初始化 `WebSearchGraph` 实例。
97
+
98
+ **属性:**
99
+
100
+ - `nodes` (Dict[str, Dict[str, str]]): 存储图中所有节点的字典。每个节点由其名称索引,并包含内容、类型以及其他相关信息。
101
+ - `adjacency_list` (Dict[str, List[str]]): 存储图中所有节点之间连接关系的邻接表。每个节点由其名称索引,并包含一个相邻节点名称的列表。
102
+
103
+
104
+ #### 方法:`add_root_node`
105
+
106
+ 添加原始问题作为根节点。
107
+ **参数:**
108
+
109
+ - `node_content` (str): 用户提出的问题。
110
+ - `node_name` (str, 可选): 节点名称,默认为 'root'。
111
+
112
+
113
+ #### 方法:`add_node`
114
+
115
+ 添加搜索子问题节点并返回搜索结果。
116
+ **参数:
117
+
118
+ - `node_name` (str): 节点名称。
119
+ - `node_content` (str): 子问题内容。
120
+
121
+ **返回:**
122
+
123
+ - `str`: 返回搜索结果。
124
+
125
+
126
+ #### 方法:`add_response_node`
127
+
128
+ 当前获取的信息已经满足问题需求,添加回复节点。
129
+
130
+ **参数:**
131
+
132
+ - `node_name` (str, 可选): 节点名称,默认为 'response'。
133
+
134
+
135
+ #### 方法:`add_edge`
136
+
137
+ 添加边。
138
+
139
+ **参数:**
140
+
141
+ - `start_node` (str): 起始节点名称。
142
+ - `end_node` (str): 结束节点名称。
143
+
144
+
145
+ #### 方法:`reset`
146
+
147
+ 重置节点和边。
148
+
149
+
150
+ #### 方法:`node`
151
+
152
+ 获取节点信息。
153
+
154
+ ```python
155
+ def node(self, node_name: str) -> str
156
+ ```
157
+
158
+ **参数:**
159
+
160
+ - `node_name` (str): 节点名称。
161
+
162
+ **返回:**
163
+
164
+ - `str`: 返回包含节点信息的字典,包含节点的内容、类型、思考过程(如果有)和前驱节点列表。
165
+
166
+ ## 任务介绍
167
+ 通过将一个问题拆分成能够通过搜索回答的子问题(没有关联的问题可以同步并列搜索),每个搜索的问题应该是一个单一问题,即单个具体人、事、物、具体时间点、地点或知识点的问题,不是一个复合问题(比如某个时间段), 一步步构建搜索图,最终回答问题。
168
+
169
+ ## 注意事项
170
+
171
+ 1. 注意,每个搜索节点的内容必须单个问题,不要包含多个问题(比如同时问多个知识点的问题或者多个事物的比较加筛选,类似 A, B, C 有什么区别,那个价格在哪个区间 -> 分别查询)
172
+ 2. 不要杜撰搜索结果,要等待代码返回结果
173
+ 3. 同样的问题不要重复提问,可以在已有问题的基础上继续提问
174
+ 4. 添加 response 节点的时候,要单独添加,不要和其他节点一起添加,不能同时添加 response 节点和其他节点
175
+ 5. 一次输出中,不要包含多个代码块,每次只能有一个代码块
176
+ 6. 每个代码块应该放置在一个代码块标记中,同时生成完代码后添加一个<|action_end|>标志,如下所示:
177
+ <|action_start|><|interpreter|>```python
178
+ # 你的代码块
179
+ ```<|action_end|>
180
+ 7. 最后一次回复应该是添加node_name为'response'的 response 节点,必须添加 response 节点,不要添加其他节点
181
+ """
182
+
183
+ GRAPH_PROMPT_EN = """## Character Profile
184
+ You are a programmer capable of Python programming in a Jupyter environment. You can utilize the provided API to construct a Web Search Graph, ultimately generating and executing code.
185
+
186
+ ## API Description
187
+
188
+ Below is the API documentation for the WebSearchGraph class, including detailed attribute descriptions:
189
+
190
+ ### Class: WebSearchGraph
191
+
192
+ This class manages nodes and edges of a web search graph and conducts searches via a web proxy.
193
+
194
+ #### Initialization Method
195
+
196
+ Initializes an instance of WebSearchGraph.
197
+
198
+ **Attributes:**
199
+
200
+ - nodes (Dict[str, Dict[str, str]]): A dictionary storing all nodes in the graph. Each node is indexed by its name and contains content, type, and other related information.
201
+ - adjacency_list (Dict[str, List[str]]): An adjacency list storing the connections between all nodes in the graph. Each node is indexed by its name and contains a list of adjacent node names.
202
+
203
+ #### Method: add_root_node
204
+
205
+ Adds the initial question as the root node.
206
+ **Parameters:**
207
+
208
+ - node_content (str): The user's question.
209
+ - node_name (str, optional): The node name, default is 'root'.
210
+
211
+ #### Method: add_node
212
+
213
+ Adds a sub-question node and returns search results.
214
+ **Parameters:**
215
+
216
+ - node_name (str): The node name.
217
+ - node_content (str): The sub-question content.
218
+
219
+ **Returns:**
220
+
221
+ - str: Returns the search results.
222
+
223
+ #### Method: add_response_node
224
+
225
+ Adds a response node when the current information satisfies the question's requirements.
226
+
227
+ **Parameters:**
228
+
229
+ - node_name (str, optional): The node name, default is 'response'.
230
+
231
+ #### Method: add_edge
232
+
233
+ Adds an edge.
234
+
235
+ **Parameters:**
236
+
237
+ - start_node (str): The starting node name.
238
+ - end_node (str): The ending node name.
239
+
240
+ #### Method: reset
241
+
242
+ Resets nodes and edges.
243
+
244
+ #### Method: node
245
+
246
+ Get node information.
247
+
248
+ python
249
+ def node(self, node_name: str) -> str
250
+
251
+ **Parameters:**
252
+
253
+ - node_name (str): The node name.
254
+
255
+ **Returns:**
256
+
257
+ - str: Returns a dictionary containing the node's information, including content, type, thought process (if any), and list of predecessor nodes.
258
+
259
+ ## Task Description
260
+ By breaking down a question into sub-questions that can be answered through searches (unrelated questions can be searched concurrently), each search query should be a single question focusing on a specific person, event, object, specific time point, location, or knowledge point. It should not be a compound question (e.g., a time period). Step by step, build the search graph to finally answer the question.
261
+
262
+ ## Considerations
263
+
264
+ 1. Each search node's content must be a single question; do not include multiple questions (e.g., do not ask multiple knowledge points or compare and filter multiple things simultaneously, like asking for differences between A, B, and C, or price ranges -> query each separately).
265
+ 2. Do not fabricate search results; wait for the code to return results.
266
+ 3. Do not repeat the same question; continue asking based on existing questions.
267
+ 4. When adding a response node, add it separately; do not add a response node and other nodes simultaneously.
268
+ 5. In a single output, do not include multiple code blocks; only one code block per output.
269
+ 6. Each code block should be placed within a code block marker, and after generating the code, add an <|action_end|> tag as shown below:
270
+ <|action_start|><|interpreter|>
271
+ ```python
272
+ # Your code block (Note that the 'Get new added node information' logic must be added at the end of the code block, such as 'graph.node('...')')
273
+ ```<|action_end|>
274
+ 7. The final response should add a response node with node_name 'response', and no other nodes should be added.
275
+ """
276
+
277
+ graph_fewshot_example_cn = """
278
+ ## 返回格式示例
279
+ <|action_start|><|interpreter|>```python
280
+ graph = WebSearchGraph()
281
+ graph.add_root_node(node_content="哪家大模型API最便宜?", node_name="root") # 添加原始问题作为根节点
282
+ graph.add_node(
283
+ node_name="大模型API提供商", # 节点名称最好有意义
284
+ node_content="目前有哪些主要的大模型API提供商?")
285
+ graph.add_node(
286
+ node_name="sub_name_2", # 节点名称最好有意义
287
+ node_content="content of sub_name_2")
288
+ ...
289
+ graph.add_edge(start_node="root", end_node="sub_name_1")
290
+ ...
291
+ graph.node("大模型API提供商"), graph.node("sub_name_2"), ...
292
+ ```<|action_end|>
293
+ """
294
+
295
+ graph_fewshot_example_en = """
296
+ ## Response Format
297
+ <|action_start|><|interpreter|>```python
298
+ graph = WebSearchGraph()
299
+ graph.add_root_node(node_content="Which large model API is the cheapest?", node_name="root") # Add the original question as the root node
300
+ graph.add_node(
301
+ node_name="Large Model API Providers", # The node name should be meaningful
302
+ node_content="Who are the main large model API providers currently?")
303
+ graph.add_node(
304
+ node_name="sub_name_2", # The node name should be meaningful
305
+ node_content="content of sub_name_2")
306
+ ...
307
+ graph.add_edge(start_node="root", end_node="sub_name_1")
308
+ ...
309
+ # Get node info
310
+ graph.node("Large Model API Providers"), graph.node("sub_name_2"), ...
311
+ ```<|action_end|>
312
+ """
313
+
314
+ FINAL_RESPONSE_CN = """基于提供的问答对,撰写一篇详细完备的最终回答。
315
+ - 回答内容需要逻辑清晰,层次分明,确保读者易于理解。
316
+ - 回答中每个关键点需标注引用的搜索结果来源(保持跟问答对中的索引一致),以确保信息的可信度。给出索引的形式为`[[int]]`,如果有多个索引,则用多个[[]]表示,如`[[id_1]][[id_2]]`。
317
+ - 回答部分需要全面且完备,不要出现"基于上述内容"等模糊表达,最终呈现的回答不包括提供给你的问答对。
318
+ - 语言风格需要专业、严谨,避免口语化表达。
319
+ - 保持统一的语法和词汇使用,确保整体文档的一致性和连贯性。"""
320
+
321
+ FINAL_RESPONSE_EN = """Based on the provided Q&A pairs, write a detailed and comprehensive final response.
322
+ - The response content should be logically clear and well-structured to ensure reader understanding.
323
+ - Each key point in the response should be marked with the source of the search results (consistent with the indices in the Q&A pairs) to ensure information credibility. The index is in the form of `[[int]]`, and if there are multiple indices, use multiple `[[]]`, such as `[[id_1]][[id_2]]`.
324
+ - The response should be comprehensive and complete, without vague expressions like "based on the above content". The final response should not include the Q&A pairs provided to you.
325
+ - The language style should be professional and rigorous, avoiding colloquial expressions.
326
+ - Maintain consistent grammar and vocabulary usage to ensure overall document consistency and coherence."""
mindsearch/agent/models.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from lagent.llms import (GPTAPI, INTERNLM2_META, HFTransformerCasualLM,
4
+ LMDeployClient, LMDeployServer)
5
+
6
+ internlm_server = dict(type=LMDeployServer,
7
+ path='internlm/internlm2_5-7b-chat',
8
+ model_name='internlm2',
9
+ meta_template=INTERNLM2_META,
10
+ top_p=0.8,
11
+ top_k=1,
12
+ temperature=0,
13
+ max_new_tokens=8192,
14
+ repetition_penalty=1.02,
15
+ stop_words=['<|im_end|>'])
16
+
17
+ internlm_client = dict(type=LMDeployClient,
18
+ model_name='internlm2_5-7b-chat',
19
+ url='http://127.0.0.1:23333',
20
+ meta_template=INTERNLM2_META,
21
+ top_p=0.8,
22
+ top_k=1,
23
+ temperature=0,
24
+ max_new_tokens=8192,
25
+ repetition_penalty=1.02,
26
+ stop_words=['<|im_end|>'])
27
+
28
+ internlm_hf = dict(type=HFTransformerCasualLM,
29
+ path='internlm/internlm2_5-7b-chat',
30
+ meta_template=INTERNLM2_META,
31
+ top_p=0.8,
32
+ top_k=None,
33
+ temperature=1e-6,
34
+ max_new_tokens=8192,
35
+ repetition_penalty=1.02,
36
+ stop_words=['<|im_end|>'])
37
+ # openai_api_base needs to fill in the complete chat api address, such as: https://api.openai.com/v1/chat/completions
38
+ gpt4 = dict(type=GPTAPI,
39
+ model_type='gpt-4-turbo',
40
+ key=os.environ.get('OPENAI_API_KEY', 'YOUR OPENAI API KEY'),
41
+ openai_api_base=os.environ.get('OPENAI_API_BASE', 'https://api.openai.com/v1/chat/completions'),
42
+ )
43
+
44
+ url = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation'
45
+ qwen = dict(type=GPTAPI,
46
+ model_type='qwen-max-longcontext',
47
+ key=os.environ.get('QWEN_API_KEY', 'YOUR QWEN API KEY'),
48
+ openai_api_base=url,
49
+ meta_template=[
50
+ dict(role='system', api_role='system'),
51
+ dict(role='user', api_role='user'),
52
+ dict(role='assistant', api_role='assistant'),
53
+ dict(role='environment', api_role='system')
54
+ ],
55
+ top_p=0.8,
56
+ top_k=1,
57
+ temperature=0,
58
+ max_new_tokens=4096,
59
+ repetition_penalty=1.02,
60
+ stop_words=['<|im_end|>'])
61
+
62
+ internlm_silicon = dict(type=GPTAPI,
63
+ model_type='internlm/internlm2_5-7b-chat',
64
+ key=os.environ.get('SILICON_API_KEY', 'YOUR SILICON API KEY'),
65
+ openai_api_base='https://api.siliconflow.cn/v1/chat/completions',
66
+ meta_template=[
67
+ dict(role='system', api_role='system'),
68
+ dict(role='user', api_role='user'),
69
+ dict(role='assistant', api_role='assistant'),
70
+ dict(role='environment', api_role='system')
71
+ ],
72
+ top_p=0.8,
73
+ top_k=1,
74
+ temperature=0,
75
+ max_new_tokens=8192,
76
+ repetition_penalty=1.02,
77
+ stop_words=['<|im_end|>'])
mindsearch/app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ from copy import deepcopy
5
+ from dataclasses import asdict
6
+ from typing import Dict, List, Union
7
+
8
+ import janus
9
+ from fastapi import FastAPI
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from lagent.schema import AgentStatusCode
12
+ from pydantic import BaseModel
13
+ from sse_starlette.sse import EventSourceResponse
14
+
15
+ from mindsearch.agent import init_agent
16
+
17
+
18
+ def parse_arguments():
19
+ import argparse
20
+ parser = argparse.ArgumentParser(description='MindSearch API')
21
+ parser.add_argument('--lang', default='cn', type=str, help='Language')
22
+ parser.add_argument('--model_format',
23
+ default='internlm_server',
24
+ type=str,
25
+ help='Model format')
26
+ parser.add_argument('--search_engine',
27
+ default='DuckDuckGoSearch',
28
+ type=str,
29
+ help='Search engine')
30
+ return parser.parse_args()
31
+
32
+
33
+ args = parse_arguments()
34
+ app = FastAPI(docs_url='/')
35
+
36
+ app.add_middleware(CORSMiddleware,
37
+ allow_origins=['*'],
38
+ allow_credentials=True,
39
+ allow_methods=['*'],
40
+ allow_headers=['*'])
41
+
42
+
43
+ class GenerationParams(BaseModel):
44
+ inputs: Union[str, List[Dict]]
45
+ agent_cfg: Dict = dict()
46
+
47
+
48
+ @app.post('/solve')
49
+ async def run(request: GenerationParams):
50
+
51
+ def convert_adjacency_to_tree(adjacency_input, root_name):
52
+
53
+ def build_tree(node_name):
54
+ node = {'name': node_name, 'children': []}
55
+ if node_name in adjacency_input:
56
+ for child in adjacency_input[node_name]:
57
+ child_node = build_tree(child['name'])
58
+ child_node['state'] = child['state']
59
+ child_node['id'] = child['id']
60
+ node['children'].append(child_node)
61
+ return node
62
+
63
+ return build_tree(root_name)
64
+
65
+ async def generate():
66
+ try:
67
+ queue = janus.Queue()
68
+ stop_event = asyncio.Event()
69
+
70
+ # Wrapping a sync generator as an async generator using run_in_executor
71
+ def sync_generator_wrapper():
72
+ try:
73
+ for response in agent.stream_chat(inputs):
74
+ queue.sync_q.put(response)
75
+ except Exception as e:
76
+ logging.exception(
77
+ f'Exception in sync_generator_wrapper: {e}')
78
+ finally:
79
+ # Notify async_generator_wrapper that the data generation is complete.
80
+ queue.sync_q.put(None)
81
+
82
+ async def async_generator_wrapper():
83
+ loop = asyncio.get_event_loop()
84
+ loop.run_in_executor(None, sync_generator_wrapper)
85
+ while True:
86
+ response = await queue.async_q.get()
87
+ if response is None: # Ensure that all elements are consumed
88
+ break
89
+ yield response
90
+ if not isinstance(
91
+ response,
92
+ tuple) and response.state == AgentStatusCode.END:
93
+ break
94
+ stop_event.set() # Inform sync_generator_wrapper to stop
95
+
96
+ async for response in async_generator_wrapper():
97
+ if isinstance(response, tuple):
98
+ agent_return, node_name = response
99
+ else:
100
+ agent_return = response
101
+ node_name = None
102
+ origin_adj = deepcopy(agent_return.adjacency_list)
103
+ adjacency_list = convert_adjacency_to_tree(
104
+ agent_return.adjacency_list, 'root')
105
+ assert adjacency_list[
106
+ 'name'] == 'root' and 'children' in adjacency_list
107
+ agent_return.adjacency_list = adjacency_list['children']
108
+ agent_return = asdict(agent_return)
109
+ agent_return['adj'] = origin_adj
110
+ response_json = json.dumps(dict(response=agent_return,
111
+ current_node=node_name),
112
+ ensure_ascii=False)
113
+ yield {'data': response_json}
114
+ # yield f'data: {response_json}\n\n'
115
+ except Exception as exc:
116
+ msg = 'An error occurred while generating the response.'
117
+ logging.exception(msg)
118
+ response_json = json.dumps(
119
+ dict(error=dict(msg=msg, details=str(exc))),
120
+ ensure_ascii=False)
121
+ yield {'data': response_json}
122
+ # yield f'data: {response_json}\n\n'
123
+ finally:
124
+ await stop_event.wait(
125
+ ) # Waiting for async_generator_wrapper to stop
126
+ queue.close()
127
+ await queue.wait_closed()
128
+
129
+ inputs = request.inputs
130
+ agent = init_agent(lang=args.lang, model_format=args.model_format,search_engine=args.search_engine)
131
+ return EventSourceResponse(generate())
132
+
133
+
134
+ if __name__ == '__main__':
135
+ import uvicorn
136
+ uvicorn.run(app, host='0.0.0.0', port=8002, log_level='info')
mindsearch/terminal.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+
3
+ from lagent.actions import ActionExecutor, BingBrowser
4
+ from lagent.llms import INTERNLM2_META, LMDeployServer
5
+
6
+ from mindsearch.agent.mindsearch_agent import (MindSearchAgent,
7
+ MindSearchProtocol)
8
+ from mindsearch.agent.mindsearch_prompt import (
9
+ FINAL_RESPONSE_CN, FINAL_RESPONSE_EN, GRAPH_PROMPT_CN, GRAPH_PROMPT_EN,
10
+ searcher_context_template_cn, searcher_context_template_en,
11
+ searcher_input_template_cn, searcher_input_template_en,
12
+ searcher_system_prompt_cn, searcher_system_prompt_en)
13
+
14
+ lang = 'cn'
15
+ llm = LMDeployServer(path='internlm/internlm2_5-7b-chat',
16
+ model_name='internlm2',
17
+ meta_template=INTERNLM2_META,
18
+ top_p=0.8,
19
+ top_k=1,
20
+ temperature=0,
21
+ max_new_tokens=8192,
22
+ repetition_penalty=1.02,
23
+ stop_words=['<|im_end|>'])
24
+
25
+ agent = MindSearchAgent(
26
+ llm=llm,
27
+ protocol=MindSearchProtocol(
28
+ meta_prompt=datetime.now().strftime('The current date is %Y-%m-%d.'),
29
+ interpreter_prompt=GRAPH_PROMPT_CN
30
+ if lang == 'cn' else GRAPH_PROMPT_EN,
31
+ response_prompt=FINAL_RESPONSE_CN
32
+ if lang == 'cn' else FINAL_RESPONSE_EN),
33
+ searcher_cfg=dict(
34
+ llm=llm,
35
+ plugin_executor=ActionExecutor(
36
+ BingBrowser(searcher_type='DuckDuckGoSearch', topk=6)),
37
+ protocol=MindSearchProtocol(
38
+ meta_prompt=datetime.now().strftime(
39
+ 'The current date is %Y-%m-%d.'),
40
+ plugin_prompt=searcher_system_prompt_cn
41
+ if lang == 'cn' else searcher_system_prompt_en,
42
+ ),
43
+ template=dict(input=searcher_input_template_cn
44
+ if lang == 'cn' else searcher_input_template_en,
45
+ context=searcher_context_template_cn
46
+ if lang == 'cn' else searcher_context_template_en)),
47
+ max_turn=10)
48
+
49
+ for agent_return in agent.stream_chat('上海今天适合穿什么衣服'):
50
+ pass
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ duckduckgo_search==5.3.1b1
2
+ einops
3
+ fastapi
4
+ git+https://github.com/InternLM/lagent.git
5
+ gradio
6
+ janus
7
+ lmdeploy
8
+ pyvis
9
+ sse-starlette
10
+ termcolor
11
+ transformers==4.41.0
12
+ uvicorn
13
+ class_registry==2.1.2