| | import asyncio |
| | import json |
| | import logging |
| | from copy import deepcopy |
| | from dataclasses import asdict |
| | from typing import Dict, List, Union |
| |
|
| | import janus |
| | from fastapi import FastAPI |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from lagent.schema import AgentStatusCode |
| | from pydantic import BaseModel |
| | from sse_starlette.sse import EventSourceResponse |
| |
|
| | from mindsearch.agent import init_agent |
| |
|
| |
|
| | def parse_arguments(): |
| | import argparse |
| | parser = argparse.ArgumentParser(description='MindSearch API') |
| | parser.add_argument('--lang', default='cn', type=str, help='Language') |
| | parser.add_argument('--model_format', |
| | default='internlm_server', |
| | type=str, |
| | help='Model format') |
| | parser.add_argument('--search_engine', |
| | default='DuckDuckGoSearch', |
| | type=str, |
| | help='Search engine') |
| | return parser.parse_args() |
| |
|
| |
|
| | args = parse_arguments() |
| | app = FastAPI(docs_url='/') |
| |
|
| | app.add_middleware(CORSMiddleware, |
| | allow_origins=['*'], |
| | allow_credentials=True, |
| | allow_methods=['*'], |
| | allow_headers=['*']) |
| |
|
| |
|
| | class GenerationParams(BaseModel): |
| | inputs: Union[str, List[Dict]] |
| | agent_cfg: Dict = dict() |
| |
|
| |
|
| | @app.post('/solve') |
| | async def run(request: GenerationParams): |
| |
|
| | def convert_adjacency_to_tree(adjacency_input, root_name): |
| |
|
| | def build_tree(node_name): |
| | node = {'name': node_name, 'children': []} |
| | if node_name in adjacency_input: |
| | for child in adjacency_input[node_name]: |
| | child_node = build_tree(child['name']) |
| | child_node['state'] = child['state'] |
| | child_node['id'] = child['id'] |
| | node['children'].append(child_node) |
| | return node |
| |
|
| | return build_tree(root_name) |
| |
|
| | async def generate(): |
| | try: |
| | queue = janus.Queue() |
| | stop_event = asyncio.Event() |
| |
|
| | |
| | def sync_generator_wrapper(): |
| | try: |
| | for response in agent.stream_chat(inputs): |
| | queue.sync_q.put(response) |
| | except Exception as e: |
| | logging.exception( |
| | f'Exception in sync_generator_wrapper: {e}') |
| | finally: |
| | |
| | queue.sync_q.put(None) |
| |
|
| | async def async_generator_wrapper(): |
| | loop = asyncio.get_event_loop() |
| | loop.run_in_executor(None, sync_generator_wrapper) |
| | while True: |
| | response = await queue.async_q.get() |
| | if response is None: |
| | break |
| | yield response |
| | if not isinstance( |
| | response, |
| | tuple) and response.state == AgentStatusCode.END: |
| | break |
| | stop_event.set() |
| |
|
| | async for response in async_generator_wrapper(): |
| | if isinstance(response, tuple): |
| | agent_return, node_name = response |
| | else: |
| | agent_return = response |
| | node_name = None |
| | origin_adj = deepcopy(agent_return.adjacency_list) |
| | adjacency_list = convert_adjacency_to_tree( |
| | agent_return.adjacency_list, 'root') |
| | assert adjacency_list[ |
| | 'name'] == 'root' and 'children' in adjacency_list |
| | agent_return.adjacency_list = adjacency_list['children'] |
| | agent_return = asdict(agent_return) |
| | agent_return['adj'] = origin_adj |
| | response_json = json.dumps(dict(response=agent_return, |
| | current_node=node_name), |
| | ensure_ascii=False) |
| | yield {'data': response_json} |
| | |
| | except Exception as exc: |
| | msg = 'An error occurred while generating the response.' |
| | logging.exception(msg) |
| | response_json = json.dumps( |
| | dict(error=dict(msg=msg, details=str(exc))), |
| | ensure_ascii=False) |
| | yield {'data': response_json} |
| | |
| | finally: |
| | await stop_event.wait( |
| | ) |
| | queue.close() |
| | await queue.wait_closed() |
| |
|
| | inputs = request.inputs |
| | agent = init_agent(lang=args.lang, model_format=args.model_format,search_engine=args.search_engine) |
| | return EventSourceResponse(generate()) |
| |
|
| |
|
| | if __name__ == '__main__': |
| | import uvicorn |
| | uvicorn.run(app, host='0.0.0.0', port=8002, log_level='info') |
| |
|