Spaces:
Running
on
L4
Running
on
L4
| import asyncio | |
| import queue | |
| import random | |
| import re | |
| import uuid | |
| from collections import defaultdict | |
| from concurrent.futures import ThreadPoolExecutor | |
| from copy import deepcopy | |
| from threading import Thread | |
| from typing import Dict, List | |
| from lagent.actions import BaseAction | |
| from lagent.schema import AgentMessage, AgentStatusCode | |
| from .streaming import AsyncStreamingAgentForInternLM, StreamingAgentForInternLM | |
class SearcherAgent(StreamingAgentForInternLM):
    """Streaming agent that renders a sub-question (plus optional context) into one prompt."""

    def __init__(
        self,
        user_input_template: str = "{question}",
        user_context_template: str = None,
        **kwargs,
    ):
        """Store the prompt templates, then defer to the parent constructor.

        Args:
            user_input_template: ``str.format`` template for the current turn;
                it is formatted with ``question`` and ``topic``.
            user_context_template: optional template applied via ``format_map``
                to every ``history`` item; context is skipped when it is falsy.
            **kwargs: forwarded unchanged to the parent class.
        """
        # Templates are assigned before the parent initialiser runs, as in
        # the original ordering.
        self.user_input_template = user_input_template
        self.user_context_template = user_context_template
        super().__init__(**kwargs)

    def forward(
        self,
        question: str,
        topic: str,
        history: List[dict] = None,
        session_id=0,
        **kwargs,
    ):
        """Build the prompt text and delegate to the parent ``forward``.

        Args:
            question: the sub-question to answer.
            topic: the overall topic made available to the input template.
            history: optional list of mappings rendered with the context
                template and placed before the current question.
            session_id: forwarded to the parent ``forward``.
            **kwargs: forwarded to the parent ``forward``.

        Returns:
            Whatever the parent ``forward`` returns for the joined prompt.
        """
        # The current turn is rendered first so template errors surface in the
        # same order as before.
        current_turn = self.user_input_template.format(question=question, topic=topic)
        context_lines = []
        if history and self.user_context_template:
            for past_turn in history:
                context_lines.append(self.user_context_template.format_map(past_turn))
        prompt = "\n".join([*context_lines, current_turn])
        return super().forward(prompt, session_id=session_id, **kwargs)
class AsyncSearcherAgent(AsyncStreamingAgentForInternLM):
    """Async streaming agent that renders a sub-question (plus optional context) into one prompt."""

    def __init__(
        self,
        user_input_template: str = "{question}",
        user_context_template: str = None,
        **kwargs,
    ):
        """Store the prompt templates, then defer to the parent constructor.

        Args:
            user_input_template: ``str.format`` template for the current turn;
                it is formatted with ``question`` and ``topic``.
            user_context_template: optional template applied via ``format_map``
                to every ``history`` item; context is skipped when it is falsy.
            **kwargs: forwarded unchanged to the parent class.
        """
        # Templates are assigned before the parent initialiser runs, as in
        # the original ordering.
        self.user_input_template = user_input_template
        self.user_context_template = user_context_template
        super().__init__(**kwargs)

    async def forward(
        self,
        question: str,
        topic: str,
        history: List[dict] = None,
        session_id=0,
        **kwargs,
    ):
        """Build the prompt text and re-yield everything the parent ``forward`` streams.

        Args:
            question: the sub-question to answer.
            topic: the overall topic made available to the input template.
            history: optional list of mappings rendered with the context
                template and placed before the current question.
            session_id: forwarded to the parent ``forward``.
            **kwargs: forwarded to the parent ``forward``.

        Yields:
            Each item produced by the parent ``forward`` for the joined prompt.
        """
        # The current turn is rendered first so template errors surface in the
        # same order as before.
        current_turn = self.user_input_template.format(question=question, topic=topic)
        context_lines = []
        if history and self.user_context_template:
            for past_turn in history:
                context_lines.append(self.user_context_template.format_map(past_turn))
        prompt = "\n".join([*context_lines, current_turn])
        async for chunk in super().forward(prompt, session_id=session_id, **kwargs):
            yield chunk
class WebSearchGraph:
    """Graph of search sub-questions whose nodes are answered by background searcher agents.

    Progress is reported through ``searcher_resp_queue`` as 3-tuples:

    * ``(node_name, node_dict, edges)`` -- a node was updated;
    * ``(None, None, None)`` -- one searcher task finished;
    * ``(exception, None, None)`` -- one searcher task failed.
    """

    # When True, searchers run as coroutines on the loops created by
    # ``start_loop``; otherwise they run on this instance's thread pool.
    is_async = False
    # Keyword arguments used to construct every searcher agent.
    SEARCHER_CONFIG = {}
    # Background event loops and the threads running them (index-aligned).
    _SEARCHER_LOOP = []
    _SEARCHER_THREAD = []

    def __init__(self):
        self.nodes: Dict[str, Dict[str, str]] = {}
        self.adjacency_list: Dict[str, List[dict]] = defaultdict(list)
        # Maps each submitted future to a "<node_name>-<node_content>" label.
        self.future_to_query = dict()
        self.searcher_resp_queue = queue.Queue()
        self.executor = ThreadPoolExecutor(max_workers=10)
        # Incremented by ``add_node``; never decremented by this class itself.
        self.n_active_tasks = 0

    def add_root_node(
        self,
        node_content: str,
        node_name: str = "root",
    ):
        """Add the starting node.

        Args:
            node_content (str): content of the node.
            node_name (str, optional): name of the node. Defaults to 'root'.
        """
        self.nodes[node_name] = dict(content=node_content, type="root")
        self.adjacency_list[node_name] = []

    def _collect_parent_responses(self, node_name: str) -> List[dict]:
        """Return ``question``/``answer`` dicts for answered nodes with an edge to ``node_name``."""
        history = []
        for start_node, edges in self.adjacency_list.items():
            parent = self.nodes.get(start_node)
            if parent is None or "response" not in parent:
                continue
            for edge in edges:
                # Edges are dicts built by ``add_edge``; the target is under
                # "name". (The previous code compared the whole dict with the
                # node name, which could never match.)
                if edge["name"] == node_name:
                    # NOTE(review): ``answer`` is the stored ``response`` value
                    # as-is; confirm the context template expects that shape.
                    history.append(dict(question=parent["content"], answer=parent["response"]))
        return history

    def _build_searcher(self, agent_cls):
        """Instantiate ``agent_cls`` from ``SEARCHER_CONFIG`` and pick a random session id."""
        cfg = {
            **self.SEARCHER_CONFIG,
            # Plugins are deep-copied so each searcher gets its own instances.
            "plugins": deepcopy(self.SEARCHER_CONFIG.get("plugins")),
        }
        return agent_cls(**cfg), random.randint(0, 999999)

    def _publish_progress(self, node_name, agent, session_id, searcher_message):
        """Store the latest searcher output on the node and enqueue an update."""
        node = self.nodes[node_name]
        node["response"] = searcher_message.model_dump()
        node["memory"] = agent.state_dict(session_id=session_id)
        node["session_id"] = session_id
        self.searcher_resp_queue.put((node_name, node, []))

    def add_node(
        self,
        node_name: str,
        node_content: str,
    ):
        """Add a search sub-question node and start answering it in the background.

        The call returns immediately. Results arrive through
        ``searcher_resp_queue`` (see the class docstring for the tuple layout).

        Args:
            node_name (str): name of the node.
            node_content (str): the sub-question to search for.
        """
        self.nodes[node_name] = dict(content=node_content, type="searcher")
        self.adjacency_list[node_name] = []
        parent_response = self._collect_parent_responses(node_name)

        if self.is_async:

            async def _async_search_node_stream():
                # Agent construction is inside the ``try`` so that a failure
                # there is reported to the consumer instead of being lost in
                # the future (which would leave the task counted as active).
                try:
                    agent, session_id = self._build_searcher(AsyncSearcherAgent)
                    async for searcher_message in agent(
                        question=node_content,
                        topic=self.nodes["root"]["content"],
                        history=parent_response,
                        session_id=session_id,
                    ):
                        self._publish_progress(node_name, agent, session_id, searcher_message)
                    self.searcher_resp_queue.put((None, None, None))
                except Exception as exc:
                    self.searcher_resp_queue.put((exc, None, None))

            # Choose the loop first so that an empty loop list fails before a
            # coroutine object is created and left un-awaited.
            loop = random.choice(self._SEARCHER_LOOP)
            future = asyncio.run_coroutine_threadsafe(_async_search_node_stream(), loop)
        else:

            def _search_node_stream():
                # See the async variant for why construction is inside ``try``.
                try:
                    agent, session_id = self._build_searcher(SearcherAgent)
                    for searcher_message in agent(
                        question=node_content,
                        topic=self.nodes["root"]["content"],
                        history=parent_response,
                        session_id=session_id,
                    ):
                        self._publish_progress(node_name, agent, session_id, searcher_message)
                    self.searcher_resp_queue.put((None, None, None))
                except Exception as exc:
                    self.searcher_resp_queue.put((exc, None, None))

            future = self.executor.submit(_search_node_stream)

        self.future_to_query[future] = f"{node_name}-{node_content}"
        self.n_active_tasks += 1

    def add_response_node(self, node_name="response"):
        """Add the final response node and enqueue an update for it.

        Args:
            node_name (str, optional): name of the node. Defaults to 'response'.
        """
        self.nodes[node_name] = dict(type="end")
        self.searcher_resp_queue.put((node_name, self.nodes[node_name], []))

    def add_edge(self, start_node: str, end_node: str):
        """Add an edge and enqueue an update for the start node.

        Args:
            start_node (str): name of the start node; it must already exist.
            end_node (str): name of the end node.

        Raises:
            KeyError: if ``start_node`` has not been added to the graph.
        """
        self.adjacency_list[start_node].append(dict(id=str(uuid.uuid4()), name=end_node, state=2))
        self.searcher_resp_queue.put(
            (start_node, self.nodes[start_node], self.adjacency_list[start_node])
        )

    def reset(self):
        """Drop all nodes and edges (queue, futures and task counter are left as they are)."""
        self.nodes = {}
        self.adjacency_list = defaultdict(list)

    def node(self, node_name: str) -> dict:
        """Return a shallow copy of the node stored under ``node_name``."""
        return self.nodes[node_name].copy()

    @classmethod
    def start_loop(cls, n: int = 32):
        """Ensure ``n`` background event loops are running for async searchers.

        Loops that stopped, or whose thread died, are dropped and replaced.

        Args:
            n (int): number of loop threads that should be alive afterwards.

        Raises:
            RuntimeError: if ``is_async`` is disabled.
        """
        if not cls.is_async:
            raise RuntimeError("Event loop cannot be launched as `is_async` is disabled")
        assert len(cls._SEARCHER_LOOP) == len(cls._SEARCHER_THREAD)

        from threading import Event  # local import: only ``Thread`` is imported at file level

        # Filter into fresh lists and assign in place. Popping by index while
        # enumerating a copy shifts positions after the first removal and can
        # discard healthy loops.
        alive = [
            (loop, thread)
            for loop, thread in zip(cls._SEARCHER_LOOP, cls._SEARCHER_THREAD)
            if loop.is_running() and thread.is_alive()
        ]
        cls._SEARCHER_LOOP[:] = [loop for loop, _ in alive]
        cls._SEARCHER_THREAD[:] = [thread for _, thread in alive]

        while len(cls._SEARCHER_THREAD) < n:
            loop = asyncio.new_event_loop()
            started = Event()

            def _start_loop(loop=loop, started=started):
                asyncio.set_event_loop(loop)
                # Signal from inside the loop so ``is_running()`` is already
                # true once the caller resumes.
                loop.call_soon(started.set)
                loop.run_forever()

            thread = Thread(target=_start_loop, daemon=True)
            thread.start()
            started.wait()
            # Both lists are appended from the calling thread so they always
            # stay index-aligned and are populated before this method returns.
            cls._SEARCHER_LOOP.append(loop)
            cls._SEARCHER_THREAD.append(thread)
class ExecutionAction(BaseAction):
    """Tool used by MindSearch planner to execute graph node query."""

    def run(self, command, local_dict, global_dict, stream_graph=False):
        """Execute planner code against the graph and wait for its searchers.

        This is a generator. While searcher tasks are active it consumes
        ``graph.searcher_resp_queue``; with ``stream_graph`` enabled it yields
        a graph snapshot for each node update once every searcher node has a
        response. The final result is delivered as the generator's return value.

        Args:
            command: text holding the code to run; a fenced or back-ticked
                snippet is extracted when present.
            local_dict: locals mapping for ``exec``; must contain ``"graph"``
                after the code has run.
            global_dict: globals mapping for ``exec``.
            stream_graph: whether to yield intermediate graph snapshots.

        Yields:
            AgentMessage: snapshot of nodes and adjacency list
                (only when ``stream_graph`` is true).

        Returns:
            tuple: ``(nodes referenced via graph.node(...), graph.nodes,
            graph.adjacency_list)``.

        Raises:
            Exception: any exception a searcher task reported through the queue.
            KeyError: if a name referenced via ``graph.node(...)`` is unknown.
        """

        def extract_code(text: str) -> str:
            # Drop the import line: the graph class is provided by the caller's namespaces.
            text = re.sub(r"from ([\w.]+) import WebSearchGraph", "", text)
            triple_match = re.search(r"```[^\n]*\n(.+?)```", text, re.DOTALL)
            if triple_match:
                return triple_match.group(1)
            single_match = re.search(r"`([^`]*)`", text, re.DOTALL)
            if single_match:
                return single_match.group(1)
            return text

        def refresh_edge_states(graph):
            # Edge state: 1 = in progress, 2 = not started, 3 = finished.
            for neighbors in graph.adjacency_list.values():
                for neighbor in neighbors:
                    target = graph.nodes.get(neighbor["name"])
                    if target is None or "response" not in target:
                        neighbor["state"] = 2
                    elif target["response"]["stream_state"] == AgentStatusCode.END:
                        neighbor["state"] = 3
                    else:
                        neighbor["state"] = 1

        command = extract_code(command)
        # SECURITY NOTE(review): this executes arbitrary code taken from
        # `command` with no sandboxing; make sure only trusted planner output
        # reaches this call.
        exec(command, global_dict, local_dict)

        # Collect every argument passed to graph.node(...) in the command.
        # The dot is escaped so only a literal "graph.node(" matches.
        node_list = re.findall(r"graph\.node\((.*?)\)", command)
        graph: WebSearchGraph = local_dict["graph"]
        resp_queue = graph.searcher_resp_queue

        while graph.n_active_tasks:
            # Block until an update arrives instead of spinning on `empty()`;
            # the timeout only bounds each wait, the loop keeps waiting as
            # long as tasks are active.
            try:
                node_name, _, _ = resp_queue.get(timeout=1.0)
            except queue.Empty:
                continue

            # Process the received item, then drain whatever else is already
            # queued before re-checking the active-task counter.
            while True:
                if isinstance(node_name, Exception):
                    raise node_name
                if node_name is None:
                    # Sentinel: one searcher task has finished.
                    graph.n_active_tasks -= 1
                elif stream_graph:
                    refresh_edge_states(graph)
                    if all(
                        "response" in node
                        for name, node in graph.nodes.items()
                        if name not in ["root", "response"]
                    ):
                        yield AgentMessage(
                            sender=self.name,
                            content=dict(current_node=node_name),
                            formatted=dict(
                                node=deepcopy(graph.nodes),
                                adjacency_list=deepcopy(graph.adjacency_list),
                            ),
                            stream_state=AgentStatusCode.STREAM_ING,
                        )
                try:
                    node_name, _, _ = resp_queue.get_nowait()
                except queue.Empty:
                    break

        res = [graph.nodes[node.strip().strip('"').strip("'")] for node in node_list]
        return res, graph.nodes, graph.adjacency_list