leagend commited on
Commit
63b2f51
·
verified ·
1 Parent(s): ce912c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -134
app.py CHANGED
@@ -1,136 +1,151 @@
1
- import asyncio
2
  import json
3
- import logging
4
- from copy import deepcopy
5
- from dataclasses import asdict
6
- from typing import Dict, List, Union
7
-
8
- import janus
9
- from fastapi import FastAPI
10
- from fastapi.middleware.cors import CORSMiddleware
11
  from lagent.schema import AgentStatusCode
12
- from pydantic import BaseModel
13
- from sse_starlette.sse import EventSourceResponse
14
-
15
- from mindsearch.agent import init_agent
16
-
17
-
18
- def parse_arguments():
19
- import argparse
20
- parser = argparse.ArgumentParser(description='MindSearch API')
21
- parser.add_argument('--lang', default='cn', type=str, help='Language')
22
- parser.add_argument('--model_format',
23
- default='internlm_server',
24
- type=str,
25
- help='Model format')
26
- parser.add_argument('--search_engine',
27
- default='DuckDuckGoSearch',
28
- type=str,
29
- help='Search engine')
30
- return parser.parse_args()
31
-
32
-
33
- args = parse_arguments()
34
- app = FastAPI(docs_url='/')
35
-
36
- app.add_middleware(CORSMiddleware,
37
- allow_origins=['*'],
38
- allow_credentials=True,
39
- allow_methods=['*'],
40
- allow_headers=['*'])
41
-
42
-
43
- class GenerationParams(BaseModel):
44
- inputs: Union[str, List[Dict]]
45
- agent_cfg: Dict = dict()
46
-
47
-
48
- @app.post('/solve')
49
- async def run(request: GenerationParams):
50
-
51
- def convert_adjacency_to_tree(adjacency_input, root_name):
52
-
53
- def build_tree(node_name):
54
- node = {'name': node_name, 'children': []}
55
- if node_name in adjacency_input:
56
- for child in adjacency_input[node_name]:
57
- child_node = build_tree(child['name'])
58
- child_node['state'] = child['state']
59
- child_node['id'] = child['id']
60
- node['children'].append(child_node)
61
- return node
62
-
63
- return build_tree(root_name)
64
-
65
- async def generate():
66
- try:
67
- queue = janus.Queue()
68
- stop_event = asyncio.Event()
69
-
70
- # Wrapping a sync generator as an async generator using run_in_executor
71
- def sync_generator_wrapper():
72
- try:
73
- for response in agent.stream_chat(inputs):
74
- queue.sync_q.put(response)
75
- except Exception as e:
76
- logging.exception(
77
- f'Exception in sync_generator_wrapper: {e}')
78
- finally:
79
- # Notify async_generator_wrapper that the data generation is complete.
80
- queue.sync_q.put(None)
81
-
82
- async def async_generator_wrapper():
83
- loop = asyncio.get_event_loop()
84
- loop.run_in_executor(None, sync_generator_wrapper)
85
- while True:
86
- response = await queue.async_q.get()
87
- if response is None: # Ensure that all elements are consumed
88
- break
89
- yield response
90
- if not isinstance(
91
- response,
92
- tuple) and response.state == AgentStatusCode.END:
93
- break
94
- stop_event.set() # Inform sync_generator_wrapper to stop
95
-
96
- async for response in async_generator_wrapper():
97
- if isinstance(response, tuple):
98
- agent_return, node_name = response
99
- else:
100
- agent_return = response
101
- node_name = None
102
- origin_adj = deepcopy(agent_return.adjacency_list)
103
- adjacency_list = convert_adjacency_to_tree(
104
- agent_return.adjacency_list, 'root')
105
- assert adjacency_list[
106
- 'name'] == 'root' and 'children' in adjacency_list
107
- agent_return.adjacency_list = adjacency_list['children']
108
- agent_return = asdict(agent_return)
109
- agent_return['adj'] = origin_adj
110
- response_json = json.dumps(dict(response=agent_return,
111
- current_node=node_name),
112
- ensure_ascii=False)
113
- yield {'data': response_json}
114
- # yield f'data: {response_json}\n\n'
115
- except Exception as exc:
116
- msg = 'An error occurred while generating the response.'
117
- logging.exception(msg)
118
- response_json = json.dumps(
119
- dict(error=dict(msg=msg, details=str(exc))),
120
- ensure_ascii=False)
121
- yield {'data': response_json}
122
- # yield f'data: {response_json}\n\n'
123
- finally:
124
- await stop_event.wait(
125
- ) # Waiting for async_generator_wrapper to stop
126
- queue.close()
127
- await queue.wait_closed()
128
-
129
- inputs = request.inputs
130
- agent = init_agent(lang=args.lang, model_format=args.model_format,search_engine=args.search_engine)
131
- return EventSourceResponse(generate())
132
-
133
-
134
- if __name__ == '__main__':
135
- import uvicorn
136
- uvicorn.run(app, host='0.0.0.0', port=8002, log_level='info')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
+
3
+ import gradio as gr
4
+ import requests
 
 
 
 
 
5
  from lagent.schema import AgentStatusCode
6
+
7
+ PLANNER_HISTORY = []
8
+ SEARCHER_HISTORY = []
9
+
10
+ import subprocess
11
+
12
+ # 定义要执行的命令行命令
13
+ command = "python -m mindsearch.app --lang cn --model_format internlm_silicon --search_engine DuckDuckGoSearch &" # 将 "your_command" 替换为你要执行的命令
14
+
15
+ # 使用subprocess.Popen来启动命令并将其放在后台
16
+ process = subprocess.Popen(command, shell=True)
17
+
18
+ print(f"Command '{command}' started with PID: {process.pid}")
19
+
20
+ def rst_mem(history_planner: list, history_searcher: list):
21
+ '''
22
+ Reset the chatbot memory.
23
+ '''
24
+ history_planner = []
25
+ history_searcher = []
26
+ if PLANNER_HISTORY:
27
+ PLANNER_HISTORY.clear()
28
+ return history_planner, history_searcher
29
+
30
+
31
+ def format_response(gr_history, agent_return):
32
+ if agent_return['state'] in [
33
+ AgentStatusCode.STREAM_ING, AgentStatusCode.ANSWER_ING
34
+ ]:
35
+ gr_history[-1][1] = agent_return['response']
36
+ elif agent_return['state'] == AgentStatusCode.PLUGIN_START:
37
+ thought = gr_history[-1][1].split('```')[0]
38
+ if agent_return['response'].startswith('```'):
39
+ gr_history[-1][1] = thought + '\n' + agent_return['response']
40
+ elif agent_return['state'] == AgentStatusCode.PLUGIN_END:
41
+ thought = gr_history[-1][1].split('```')[0]
42
+ if isinstance(agent_return['response'], dict):
43
+ gr_history[-1][
44
+ 1] = thought + '\n' + f'```json\n{json.dumps(agent_return["response"], ensure_ascii=False, indent=4)}\n```' # noqa: E501
45
+ elif agent_return['state'] == AgentStatusCode.PLUGIN_RETURN:
46
+ assert agent_return['inner_steps'][-1]['role'] == 'environment'
47
+ item = agent_return['inner_steps'][-1]
48
+ gr_history.append([
49
+ None,
50
+ f"```json\n{json.dumps(item['content'], ensure_ascii=False, indent=4)}\n```"
51
+ ])
52
+ gr_history.append([None, ''])
53
+ return
54
+
55
+
56
+ def predict(history_planner, history_searcher):
57
+
58
+ def streaming(raw_response):
59
+ for chunk in raw_response.iter_lines(chunk_size=8192,
60
+ decode_unicode=False,
61
+ delimiter=b'\n'):
62
+ if chunk:
63
+ decoded = chunk.decode('utf-8')
64
+ if decoded == '\r':
65
+ continue
66
+ if decoded[:6] == 'data: ':
67
+ decoded = decoded[6:]
68
+ elif decoded.startswith(': ping - '):
69
+ continue
70
+ response = json.loads(decoded)
71
+ yield (response['response'], response['current_node'])
72
+
73
+ global PLANNER_HISTORY
74
+ PLANNER_HISTORY.append(dict(role='user', content=history_planner[-1][0]))
75
+ new_search_turn = True
76
+
77
+ url = 'http://localhost:8002/solve'
78
+ headers = {'Content-Type': 'application/json'}
79
+ data = {'inputs': PLANNER_HISTORY}
80
+ raw_response = requests.post(url,
81
+ headers=headers,
82
+ data=json.dumps(data),
83
+ timeout=20,
84
+ stream=True)
85
+
86
+ for resp in streaming(raw_response):
87
+ agent_return, node_name = resp
88
+ if node_name:
89
+ if node_name in ['root', 'response']:
90
+ continue
91
+ agent_return = agent_return['nodes'][node_name]['detail']
92
+ if new_search_turn:
93
+ history_searcher.append([agent_return['content'], ''])
94
+ new_search_turn = False
95
+ format_response(history_searcher, agent_return)
96
+ if agent_return['state'] == AgentStatusCode.END:
97
+ new_search_turn = True
98
+ yield history_planner, history_searcher
99
+ else:
100
+ new_search_turn = True
101
+ format_response(history_planner, agent_return)
102
+ if agent_return['state'] == AgentStatusCode.END:
103
+ PLANNER_HISTORY = agent_return['inner_steps']
104
+ yield history_planner, history_searcher
105
+ return history_planner, history_searcher
106
+
107
+
108
+ with gr.Blocks() as demo:
109
+ gr.HTML("""<h1 align="center">WebAgent Gradio Simple Demo</h1>""")
110
+ with gr.Row():
111
+ with gr.Column(scale=10):
112
+ with gr.Row():
113
+ with gr.Column():
114
+ planner = gr.Chatbot(label='planner',
115
+ height=700,
116
+ show_label=True,
117
+ show_copy_button=True,
118
+ bubble_full_width=False,
119
+ render_markdown=True)
120
+ with gr.Column():
121
+ searcher = gr.Chatbot(label='searcher',
122
+ height=700,
123
+ show_label=True,
124
+ show_copy_button=True,
125
+ bubble_full_width=False,
126
+ render_markdown=True)
127
+ with gr.Row():
128
+ user_input = gr.Textbox(show_label=False,
129
+ placeholder='inputs...',
130
+ lines=5,
131
+ container=False)
132
+ with gr.Row():
133
+ with gr.Column(scale=2):
134
+ submitBtn = gr.Button('Submit')
135
+ with gr.Column(scale=1, min_width=20):
136
+ emptyBtn = gr.Button('Clear History')
137
+
138
+ def user(query, history):
139
+ return '', history + [[query, '']]
140
+
141
+ submitBtn.click(user, [user_input, planner], [user_input, planner],
142
+ queue=False).then(predict, [planner, searcher],
143
+ [planner, searcher])
144
+ emptyBtn.click(rst_mem, [planner, searcher], [planner, searcher],
145
+ queue=False)
146
+
147
+ demo.queue()
148
+ demo.launch(server_name='127.0.0.1',
149
+ server_port=7882,
150
+ inbrowser=True,
151
+ share=True)