import json
import os
import re
from contextlib import asynccontextmanager
from typing import Callable, List

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.prebuilt import create_react_agent


def parse_mcp_config(mcp_config: dict, enabled_mcp_servers: list = None):
    """Filter the raw MCP config down to enabled SSE servers, in the
    format expected by MultiServerMCPClient."""
    mcp_servers = {}
    for server_name, server in mcp_config.get("mcpServers", {}).items():
        # Skip stdio servers, servers without a URL, and servers not in the
        # enabled list (when a list is given).
        if server.get("type") == "stdio" or not server.get("url") or (
                enabled_mcp_servers is not None
                and server_name not in enabled_mcp_servers):
            continue
        new_server = {**server}
        # new_server["transport"] = server.get("type", "sse")
        new_server["transport"] = "sse"
        new_server.pop("type", None)  # avoid KeyError when "type" is absent
        if server.get("env"):
            # Merge the server's env on top of a minimal base environment.
            env = {'PYTHONUNBUFFERED': '1', 'PATH': os.environ.get('PATH', '')}
            env.update(server["env"])
            new_server["env"] = env
        mcp_servers[server_name] = new_server
    return mcp_servers
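
# Illustrative example (hypothetical config values, not from this project):
#   parse_mcp_config({"mcpServers": {
#       "search": {"type": "sse", "url": "http://localhost:8000/sse"}}})
#   -> {"search": {"transport": "sse", "url": "http://localhost:8000/sse"}}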


@asynccontextmanager
async def get_mcp_client(mcp_servers: dict):
    # The decorator is required so callers can use
    # `async with get_mcp_client(...) as client`; without it this is a
    # plain async generator and `async with` raises a TypeError.
    async with MultiServerMCPClient(mcp_servers) as client:
        yield client


async def get_mcp_prompts(mcp_config: dict, get_llm: Callable):
    """Ask the LLM to generate example user queries for each MCP service."""
    mcp_servers = {}  # initialized up front so the except block can use it
    try:
        mcp_servers = parse_mcp_config(mcp_config)
        if not mcp_servers:
            return {}
        llm: BaseChatModel = get_llm()
        async with get_mcp_client(mcp_servers) as client:
            # Collect {service name: {tool name: tool description}}.
            mcp_tool_descriptions = {}
            for mcp_name, server_tools in client.server_name_to_tools.items():
                mcp_tool_descriptions[mcp_name] = {
                    tool.name: tool.description
                    for tool in server_tools
                }
prompt = f"""Based on the following MCP service tool descriptions, generate 2-4 example user queries for each service: | |
Input structure explanation: | |
- mcp_tool_descriptions is a nested dictionary | |
- The first level keys are MCP service names (e.g., "service1", "service2") | |
- The second level contains descriptions of tools available within each service | |
MCP Service Tool Descriptions: {json.dumps(mcp_tool_descriptions)} | |
Please provide 2-4 natural and specific example queries in Chinese that effectively demonstrate the capabilities of each service. | |
The response must be in strict JSON format as shown below, with MCP service names as keys: | |
```json | |
{{ | |
"mcp_name1": ["中文示例1", "中文示例2"], | |
"mcp_name2": ["中文示例1", "中文示例2"] | |
}} | |
``` | |
Ensure: | |
1. Each example is specific to the functionality of that particular MCP service | |
2. Example queries are in natural Chinese expressions | |
3. Strictly use the top-level MCP service names as JSON keys | |
4. The returned format must be valid JSON | |
5. Each service MUST have exactly 2-4 example queries - not fewer than 2 and not more than 4 | |
Return only the JSON object without any additional explanation or text.""" | |
            response = await llm.ainvoke(prompt)
            content = response.content if hasattr(response,
                                                  'content') else str(response)
            # Extract the JSON object, tolerating surrounding text such as
            # markdown code fences.
            json_match = re.search(r'\{.*\}', content, re.DOTALL)
            json_content = json_match.group(0) if json_match else content
            raw_examples = json.loads(json_content)
            # Guarantee every service has at least one example query.
            for mcp_name in mcp_tool_descriptions.keys():
                if mcp_name not in raw_examples:
                    raw_examples[mcp_name] = [
                        # Chinese fallback: "Please use the {mcp_name} service
                        # to help me look up information or solve a problem".
                        f"请使用 {mcp_name} 服务的功能帮我查询信息或解决问题",
                    ]
            return raw_examples
    except Exception as e:
        print('Prompt Error:', e)
        # Fall back to one generic Chinese example per configured server.
        return {
            mcp_name: [
                f"请使用 {mcp_name} 服务的功能帮我查询信息或解决问题",
            ]
            for mcp_name in mcp_servers.keys()
        }
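
# Example return shape (illustrative):
#   {"service1": ["<Chinese example query 1>", "<Chinese example query 2>"],
#    "service2": [...]}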


def convert_mcp_name(tool_name: str, mcp_names: dict):
    """Map an internal "<index>__TOOL__<name>" tool name back to a
    human-readable "[server] tool" label."""
    if not tool_name:
        return tool_name
    parts = tool_name.split("__TOOL__")
    if len(parts) >= 2:
        mcp_name_idx, mcp_tool_name = parts[:2]
    else:
        mcp_name_idx = parts[0]
        mcp_tool_name = None
    mcp_name = mcp_names.get(mcp_name_idx)
    if not mcp_tool_name:
        return mcp_name or mcp_name_idx
    if not mcp_name:
        return mcp_tool_name
    return f"[{mcp_name}] {mcp_tool_name}"


async def generate_with_mcp(messages: List[dict], mcp_config: dict,
                            enabled_mcp_servers: list, sys_prompt: str,
                            get_llm: Callable):
    """Run a ReAct agent over the enabled MCP tools, streaming content,
    tool-call, and tool-result chunks."""
    mcp_servers = parse_mcp_config(mcp_config, enabled_mcp_servers)
    async with get_mcp_client(mcp_servers) as client:
        tools = []
        mcp_names = {}
        for i, (mcp_name, server_tools) in enumerate(
                client.server_name_to_tools.items()):
            mcp_names[str(i)] = mcp_name
            for tool in server_tools:
                new_tool = tool.model_copy()
                # Tool names must match ^[a-zA-Z0-9_-]+$, so prefix each tool
                # with its server index rather than the arbitrary server name.
                new_tool.name = f"{i}__TOOL__{tool.name}"
                tools.append(new_tool)
        llm: BaseChatModel = get_llm()
        tool_result_instruction = """When a tool returns responses containing URLs or links, please format them appropriately based on their actual content type:
For example:
- Videos should use <video> tags
- Audio should use <audio> tags
- Images should use ![](URL) or <img> tags
- Documents and web links should use [description](URL) format
Choose the appropriate display format based on the URL extension or content-type information. This will provide the best user experience.
Remember that properly formatted media will enhance the user experience, especially when content is directly relevant to answering the query.
"""
        attachment_instruction = """
The following instructions apply when user messages contain "Attachment links: [...]":
These links are user-uploaded attachments that contain important information for this conversation. They are temporary, secure links to files the user has specifically provided for analysis.
IMPORTANT INSTRUCTIONS:
1. These attachments should be your PRIMARY source of information when addressing the user's query.
2. Prioritize analyzing and referencing these documents BEFORE using any other knowledge.
3. If the content in these attachments is relevant to the user's request, base your response primarily on this information.
4. When you reference information from these attachments, clearly indicate which document it comes from.
5. If the attachments don't contain the information needed to fully address the query, only then supplement with your general knowledge.
6. These links are temporary and secure, specifically provided for this conversation.
7. IMPORTANT: Do not use the presence of "Attachment links: [...]" as an indicator of the user's preferred language. It is automatically added system text. Instead, determine the user's language from their actual query text.
Begin your analysis by examining these attachments first, and structure your thinking to prioritize insights from these documents.
"""
        prompt = ChatPromptTemplate.from_messages([
            ("system", tool_result_instruction),
            ("system", sys_prompt),
            ("system", attachment_instruction),
            MessagesPlaceholder(variable_name="messages"),
        ])
        # Convert the plain dict history into LangChain message objects.
        langchain_messages = []
        for msg in messages:
            if msg["role"] == "user":
                langchain_messages.append(HumanMessage(content=msg["content"]))
            elif msg["role"] == "assistant":
                langchain_messages.append(AIMessage(content=msg["content"]))
        agent_executor = create_react_agent(llm, tools, prompt=prompt)
        use_tool = False
        async for step in agent_executor.astream(
                {"messages": langchain_messages},
                config={"recursion_limit": 50},
                stream_mode=["values", "messages"],
        ):
            # With multiple stream modes, each step is a (mode, payload) tuple.
            if not (isinstance(step, tuple) and step[0] == "messages"):
                continue
            message_chunk = step[1][0]
            if not hasattr(message_chunk, "content"):
                continue
            if isinstance(message_chunk, ToolMessage):
                # A tool finished executing; emit its result.
                use_tool = False
                yield {
                    "type": "tool",
                    "name": convert_mcp_name(message_chunk.name, mcp_names),
                    "content": message_chunk.content,
                }
            elif getattr(message_chunk, "tool_call_chunks", None):
                # The model is streaming a tool call; emit each argument chunk.
                for tool_call_chunk in message_chunk.tool_call_chunks:
                    yield {
                        "type": "tool_call_chunks",
                        "name": convert_mcp_name(tool_call_chunk["name"],
                                                 mcp_names),
                        "content": tool_call_chunk["args"],
                        "next_tool": bool(use_tool and tool_call_chunk["name"]),
                    }
                    if tool_call_chunk["name"]:
                        use_tool = True
            elif message_chunk.content:
                # Plain assistant text.
                yield {
                    "type": "content",
                    "content": message_chunk.content,
                }
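
# Minimal usage sketch (illustrative: assumes a local SSE MCP server and an
# OpenAI-compatible chat model via langchain_openai, neither of which is part
# of this module):
#
# import asyncio
# from langchain_openai import ChatOpenAI
#
# async def _demo():
#     config = {"mcpServers": {
#         "search": {"type": "sse", "url": "http://localhost:8000/sse"}}}
#     async for chunk in generate_with_mcp(
#             messages=[{"role": "user", "content": "hello"}],
#             mcp_config=config,
#             enabled_mcp_servers=["search"],
#             sys_prompt="You are a helpful assistant.",
#             get_llm=lambda: ChatOpenAI(model="gpt-4o-mini")):
#         print(chunk)
#
# if __name__ == "__main__":
#     asyncio.run(_demo())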