# Spaces: Sleeping / Sleeping — Hugging Face Spaces status banner captured
# with the page scrape; not part of the program source.
| import streamlit as st | |
| import asyncio | |
| import nest_asyncio | |
| import json | |
| import os | |
| import platform | |
| import time | |
| import subprocess | |
| import threading | |
| import signal | |
| import sys | |
# Windows needs the Proactor policy for asyncio subprocess/pipe support.
if platform.system() == "Windows":
    asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())

# Apply nest_asyncio: allow nested run_until_complete() calls within an
# already running event loop (Streamlit reruns the script on interaction).
nest_asyncio.apply()

# Create one global event loop and reuse it across Streamlit reruns
# (created once, stored in session state).
if "event_loop" not in st.session_state:
    loop = asyncio.new_event_loop()
    st.session_state.event_loop = loop
    asyncio.set_event_loop(loop)
| from langgraph.prebuilt import create_react_agent | |
| from langchain_anthropic import ChatAnthropic | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.messages import HumanMessage | |
| from dotenv import load_dotenv | |
| from langchain_mcp_adapters.client import MultiServerMCPClient | |
| from utils import astream_graph, random_uuid | |
| from langchain_core.messages.ai import AIMessageChunk | |
| from langchain_core.messages.tool import ToolMessage | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langchain_core.runnables import RunnableConfig | |
# Load environment variables (API keys and settings from the .env file);
# override=True lets .env values win over already-set process env vars.
load_dotenv(override=True)

# Path of the JSON file holding the MCP server configuration.
CONFIG_FILE_PATH = "config.json"
# Function to load settings from a JSON file
def load_config_from_json(path=None):
    """
    Loads settings from a JSON config file.

    Creates the file with default settings if it doesn't exist.

    Args:
        path (str | None): Config file path; defaults to CONFIG_FILE_PATH
            ("config.json") for backward compatibility.

    Returns:
        dict: Loaded settings, or the default settings on first run or
        on a read/parse error.
    """
    default_config = {
        "get_current_time": {
            "command": "python",
            "args": ["./mcp_server_time.py"],
            "transport": "stdio"
        }
    }
    if path is None:
        path = CONFIG_FILE_PATH
    try:
        if os.path.exists(path):
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        # First run: create the file with the default settings.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(default_config, f, indent=2, ensure_ascii=False)
        return default_config
    # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
    except (OSError, ValueError) as e:
        st.error(f"Error loading settings file: {str(e)}")
        return default_config
# Function to save settings to a JSON file
def save_config_to_json(config, path=None):
    """
    Saves settings to a JSON config file.

    Args:
        config (dict): Settings to save.
        path (str | None): Config file path; defaults to CONFIG_FILE_PATH
            ("config.json") for backward compatibility.

    Returns:
        bool: True on success, False on failure (error shown in the UI).
    """
    if path is None:
        path = CONFIG_FILE_PATH
    try:
        with open(path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, ensure_ascii=False)
        return True
    # TypeError/ValueError: unserializable content; OSError: file I/O.
    except (OSError, TypeError, ValueError) as e:
        st.error(f"Error saving settings file: {str(e)}")
        return False
def start_retrieve_service():
    """
    Start the Retrieve backend service ("python main.py" run from
    python-services/Retrieve) as a background subprocess.

    Reuses the process stored in st.session_state if it is still alive.

    Returns:
        bool: True if the service is running (already, or newly started),
        False if startup failed.
    """
    try:
        # Reuse an existing process if it is still alive.
        if "retrieve_process" in st.session_state and st.session_state.retrieve_process:
            try:
                if st.session_state.retrieve_process.poll() is None:
                    st.info("✅ Retrieve 服务已经在运行")
                    return True
            except Exception:
                # Stale/invalid process handle — fall through and restart.
                pass

        st.info("🚀 正在启动 Retrieve 服务...")

        # Launch the service; cwd selects the service's working directory.
        process = subprocess.Popen(
            ["python", "main.py"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,
            universal_newlines=True,
            cwd="python-services/Retrieve",
        )
        st.session_state.retrieve_process = process
        st.session_state.retrieve_started = True

        # Read each pipe on its own daemon thread. A single loop doing
        # stdout.readline() followed by stderr.readline() blocks on a
        # silent stream and starves (or deadlocks) the other pipe.
        # NOTE(review): st.* calls from background threads run without a
        # ScriptRunContext and may be dropped by Streamlit — confirm.
        def _pump(stream, emit, prefix):
            try:
                for raw in iter(stream.readline, ""):
                    if raw.strip():
                        emit(f"{prefix}{raw.strip()}")
            except Exception as e:
                st.error(f"监控 Retrieve 服务时出错: {str(e)}")

        threading.Thread(
            target=_pump,
            args=(process.stdout, st.info, "Retrieve 服务: "),
            daemon=True,
        ).start()
        threading.Thread(
            target=_pump,
            args=(process.stderr, st.warning, "Retrieve 服务错误: "),
            daemon=True,
        ).start()

        # Separate watcher reports process exit and clears the flag.
        def _watch_exit():
            process.wait()
            st.warning(f"Retrieve 服务已停止,退出码: {process.returncode}")
            st.session_state.retrieve_started = False

        threading.Thread(target=_watch_exit, daemon=True).start()

        # Give the process a moment so immediate startup failures are caught.
        time.sleep(2)
        if process.poll() is None:
            st.success("✅ Retrieve 服务启动成功")
            return True
        st.error("❌ Retrieve 服务启动失败")
        return False
    except Exception as e:
        st.error(f"启动 Retrieve 服务时出错: {str(e)}")
        return False
def stop_retrieve_service():
    """
    Stop the background Retrieve service process, if one is running.

    Sends SIGTERM first, escalates to SIGKILL after a 5-second grace
    period, and always clears the session-state bookkeeping.
    """
    try:
        if "retrieve_process" in st.session_state and st.session_state.retrieve_process:
            process = st.session_state.retrieve_process
            if process.poll() is None:
                # Graceful termination request first.
                process.terminate()
                try:
                    process.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    # Escalate, then reap so no zombie process is left.
                    process.kill()
                    process.wait()
                st.success("✅ Retrieve 服务已停止")
            else:
                st.info("Retrieve 服务已经停止")
            st.session_state.retrieve_started = False
            st.session_state.retrieve_process = None
    except Exception as e:
        st.error(f"停止 Retrieve 服务时出错: {str(e)}")
# Initialize login session variables.
if "authenticated" not in st.session_state:
    st.session_state.authenticated = False

# Login is required only when USE_LOGIN=true in the environment.
use_login = os.environ.get("USE_LOGIN", "false").lower() == "true"

# Page settings depend on login state: narrow layout for the login form,
# wide layout for the main app.
if use_login and not st.session_state.authenticated:
    st.set_page_config(page_title="Agent with MCP Tools", page_icon="🧠")
else:
    st.set_page_config(page_title="Agent with MCP Tools", page_icon="🧠", layout="wide")

# Display the login screen if login is enabled and not yet authenticated.
if use_login and not st.session_state.authenticated:
    st.title("🔐 Login")
    st.markdown("Login is required to use the system.")
    with st.form("login_form"):
        username = st.text_input("Username")
        password = st.text_input("Password", type="password")
        submit_button = st.form_submit_button("Login")
        if submit_button:
            # Expected credentials come from the environment (.env).
            expected_username = os.environ.get("USER_ID")
            expected_password = os.environ.get("USER_PASSWORD")
            # NOTE(review): plaintext, non-constant-time comparison of
            # credentials — consider secrets.compare_digest; confirm the
            # threat model for this deployment.
            if username == expected_username and password == expected_password:
                st.session_state.authenticated = True
                st.success("✅ Login successful! Please wait...")
                st.rerun()
            else:
                st.error("❌ Username or password is incorrect.")
    # Don't render the main app below the login screen.
    st.stop()
# Author/credit link at the top of the sidebar (placed before the other
# sidebar elements so it always renders first).
st.sidebar.markdown(
    "### 🔬 [Automated-DATA-Extractor](https://huggingface.co/spaces/jackkuo/Automated-Enzyme-Kinetics-Extractor)"
)
st.sidebar.divider()  # Visual separator below the credit link

# Main page title and description.
st.title("💬 MCP Tool Utilization Agent")
st.markdown("✨ Ask questions to the ReAct agent that utilizes MCP tools.")
# System prompt handed to the ReAct agent. This is runtime text sent to the
# model — kept verbatim (including its original phrasing).
SYSTEM_PROMPT = """<ROLE>
You are a smart agent with an ability to use tools.
You will be given a question and you will use the tools to answer the question.
Pick the most relevant tool to answer the question.
If you are failed to answer the question, try different tools to get context.
Your answer should be very polite and professional.
</ROLE>
----
<INSTRUCTIONS>
Step 1: Analyze the question
- Analyze user's question and final goal.
- If the user's question is consist of multiple sub-questions, split them into smaller sub-questions.
Step 2: Pick the most relevant tool
- Pick the most relevant tool to answer the question.
- If you are failed to answer the question, try different tools to get context.
Step 3: Answer the question
- Answer the question in the same language as the question.
- Your answer should be very polite and professional.
Step 4: Provide the source of the answer(if applicable)
- If you've used the tool, provide the source of the answer.
- Valid sources are either a website(URL) or a document(PDF, etc).
Guidelines:
- If you've used the tool, your answer should be based on the tool's output(tool's output is more important than your own knowledge).
- If you've used the tool, and the source is valid URL, provide the source(URL) of the answer.
- Skip providing the source if the source is not URL.
- Answer in the same language as the question.
- Answer should be concise and to the point.
- Avoid response your output with any other information than the answer and the source.
</INSTRUCTIONS>
----
<OUTPUT_FORMAT>
(concise answer to the question)
**Source**(if applicable)
- (source1: valid URL)
- (source2: valid URL)
- ...
</OUTPUT_FORMAT>
"""
# Per-model output cap, passed as `max_tokens` when building the chat model.
# NOTE(review): Anthropic documents an 8192-token output cap for
# claude-3-5-sonnet-20241022, so 64000 looks too high — confirm against
# the provider's current limits before relying on it.
OUTPUT_TOKEN_INFO = {
    "claude-3-5-sonnet-latest": {"max_tokens": 8192},
    "claude-3-5-haiku-latest": {"max_tokens": 8192},
    "claude-3-5-sonnet-20241022": {"max_tokens": 64000},
    "gpt-4o": {"max_tokens": 4096},  # previously 16000
    "gpt-4o-mini": {"max_tokens": 16000},
}
# One-time per-session initialization of application state.
if "session_initialized" not in st.session_state:
    st.session_state.session_initialized = False  # Session initialization flag
    st.session_state.agent = None  # Storage for the ReAct agent object
    st.session_state.history = []  # List for storing conversation history
    st.session_state.mcp_client = None  # Storage for the MCP client object
    st.session_state.timeout_seconds = (
        # Response generation time limit in seconds.
        # NOTE(review): original comment said "default 120 seconds" but the
        # value is 30000 (~8.3 hours) — confirm which is intended.
        30000
    )
    st.session_state.selected_model = (
        "claude-3-5-sonnet-20241022"  # Default model selection
    )
    st.session_state.recursion_limit = 100  # Recursion call limit, default 100

# Stable thread id so the LangGraph checkpointer keeps one conversation
# across Streamlit reruns.
if "thread_id" not in st.session_state:
    st.session_state.thread_id = random_uuid()
| # --- Function Definitions --- | |
async def cleanup_mcp_client():
    """
    Safely terminates the existing MCP client.

    Newer versions of langchain-mcp-adapters clients are not async context
    managers, so dropping our reference is the entire teardown.
    """
    if "mcp_client" in st.session_state and st.session_state.mcp_client is not None:
        # Best-effort cleanup: never let teardown break the UI flow.
        try:
            st.session_state.mcp_client = None
        except Exception:
            pass
def print_message():
    """
    Render the stored chat history on screen.

    User and assistant messages get their own chat bubbles; when an
    "assistant_tool" entry immediately follows an assistant entry, its
    content is rendered inside the same assistant bubble as a collapsed
    "Tool Call Information" expander.
    """
    history = st.session_state.history
    total = len(history)
    idx = 0
    while idx < total:
        entry = history[idx]
        role = entry["role"]
        if role == "user":
            st.chat_message("user", avatar="🧑💻").markdown(entry["content"])
            idx += 1
            continue
        if role == "assistant":
            # One container holds the reply and any attached tool info.
            with st.chat_message("assistant", avatar="🤖"):
                st.markdown(entry["content"])
                followed_by_tool = (
                    idx + 1 < total
                    and history[idx + 1]["role"] == "assistant_tool"
                )
                if followed_by_tool:
                    with st.expander("🔧 Tool Call Information", expanded=False):
                        st.markdown(history[idx + 1]["content"])
                    idx += 2  # Consumed the assistant entry plus its tool entry
                else:
                    idx += 1
            continue
        # "assistant_tool" entries were already rendered with their
        # preceding assistant message — skip them here.
        idx += 1
def get_streaming_callback(text_placeholder, tool_placeholder):
    """
    Creates a streaming callback function.

    The callback renders LLM output in real time: plain text is written to
    `text_placeholder`, while tool calls, tool results, and parsed
    SSE-style ("data: ...") payloads from MCP tools are written to
    `tool_placeholder` inside an expander.

    Cross-call state (dedup of tool names, stream buffers, UI-update
    throttling timestamps, pending tool completions) is kept as attributes
    on the callback function object itself.

    NOTE(review): the callback rebinds the nonlocal `accumulated_tool`
    list in places; the list object returned below can therefore become
    stale after the first rebind — confirm callers only read it before
    streaming starts.

    Args:
        text_placeholder: Streamlit component to display text responses
        tool_placeholder: Streamlit component to display tool call information

    Returns:
        callback_func: Streaming callback function
        accumulated_text: List to store accumulated text responses
        accumulated_tool: List to store accumulated tool call information
    """
    accumulated_text = []
    accumulated_tool = []

    def callback_func(message: dict):
        nonlocal accumulated_text, accumulated_tool
        message_content = message.get("content", None)
        # Initialize data counter for tracking "data:" messages.
        if not hasattr(callback_func, '_data_counter'):
            callback_func._data_counter = 0
        # Initialize tool result tracking.
        if not hasattr(callback_func, '_tool_results'):
            callback_func._tool_results = {}
        # Check if this is a tool result message.
        if isinstance(message_content, dict) and 'tool_results' in message_content:
            tool_results = message_content['tool_results']
            for tool_name, result in tool_results.items():
                callback_func._tool_results[tool_name] = result
        # Check if this is a tool call completion message.
        if isinstance(message_content, dict) and 'tool_calls' in message_content:
            tool_calls = message_content['tool_calls']
            for tool_call in tool_calls:
                if isinstance(tool_call, dict) and 'name' in tool_call:
                    tool_name = tool_call['name']
                    if 'result' in tool_call:
                        # Store tool result keyed by tool name.
                        callback_func._tool_results[tool_name] = tool_call['result']
        # Handle different message types.
        if isinstance(message_content, AIMessageChunk):
            # Process AIMessageChunk content.
            content = message_content.content
            # If content is in list form (mainly occurs in Claude models).
            if isinstance(content, list) and len(content) > 0:
                message_chunk = content[0]
                # Process text type.
                if message_chunk["type"] == "text":
                    accumulated_text.append(message_chunk["text"])
                    text_placeholder.markdown("".join(accumulated_text))
                # Process tool use type.
                elif message_chunk["type"] == "tool_use":
                    if "partial_json" in message_chunk:
                        accumulated_tool.append(message_chunk["partial_json"])
                    else:
                        tool_call_chunks = message_content.tool_call_chunks
                        tool_call_chunk = tool_call_chunks[0]
                        accumulated_tool.append(
                            "\n```json\n" + str(tool_call_chunk) + "\n```\n"
                        )
                    with tool_placeholder.expander(
                        "🔧 Tool Call Information", expanded=True
                    ):
                        st.markdown("".join(accumulated_tool))
            # Process if tool_calls attribute exists (mainly OpenAI models).
            elif (
                hasattr(message_content, "tool_calls")
                and message_content.tool_calls
                and len(message_content.tool_calls[0]["name"]) > 0
            ):
                tool_call_info = message_content.tool_calls[0]
                accumulated_tool.append("\n```json\n" + str(tool_call_info) + "\n```\n")
                with tool_placeholder.expander(
                    "🔧 Tool Call Information", expanded=True
                ):
                    st.markdown("".join(accumulated_tool))
            # Process if content is a simple string.
            elif isinstance(content, str):
                # Regular text content.
                accumulated_text.append(content)
                text_placeholder.markdown("".join(accumulated_text))
            # Process if invalid tool call information exists.
            elif (
                hasattr(message_content, "invalid_tool_calls")
                and message_content.invalid_tool_calls
            ):
                tool_call_info = message_content.invalid_tool_calls[0]
                accumulated_tool.append("\n```json\n" + str(tool_call_info) + "\n```\n")
                with tool_placeholder.expander("🔧 Tool Call Information (Invalid)", expanded=True):
                    st.markdown("".join(accumulated_tool))
            # Process if tool_call_chunks attribute exists.
            elif (
                hasattr(message_content, "tool_call_chunks")
                and message_content.tool_call_chunks
            ):
                tool_call_chunk = message_content.tool_call_chunks[0]
                tool_name = tool_call_chunk.get('name', 'Unknown')
                # Only announce the tool once per run of chunks for the
                # same tool (dedup via _last_tool_name).
                if not hasattr(callback_func, '_last_tool_name') or callback_func._last_tool_name != tool_name:
                    accumulated_tool.append(
                        f"\n🔧 **Tool Call**: {tool_name}\n"
                    )
                    callback_func._last_tool_name = tool_name
                # Show tool call details in a more compact format.
                accumulated_tool.append(
                    f"```json\n{str(tool_call_chunk)}\n```\n"
                )
                with tool_placeholder.expander("🔧 Tool Call Information", expanded=True):
                    st.markdown("".join(accumulated_tool))
            # Process if tool_calls exists in additional_kwargs (covers
            # various model-compatibility shapes).
            elif (
                hasattr(message_content, "additional_kwargs")
                and "tool_calls" in message_content.additional_kwargs
            ):
                tool_call_info = message_content.additional_kwargs["tool_calls"][0]
                accumulated_tool.append("\n```json\n" + str(tool_call_info) + "\n```\n")
                with tool_placeholder.expander("🔧 Tool Call Information", expanded=True):
                    st.markdown("".join(accumulated_tool))
        # Process if it's a tool message (tool response).
        elif isinstance(message_content, ToolMessage):
            # Don't show "Tool Completed" immediately — record the tool name
            # and flush it after all streaming content has been handled.
            if not hasattr(callback_func, '_pending_tool_completion'):
                callback_func._pending_tool_completion = []
            tool_name = message_content.name or "Unknown Tool"
            callback_func._pending_tool_completion.append(tool_name)
            # Debug: log that a tool message was received.
            accumulated_tool.append(f"\n🔍 **Tool Message Received**: {tool_name}\n")
            accumulated_tool.append(f"📋 **Message Type**: {type(message_content).__name__}\n")
            # Convert interim "Streaming Text" entries into one final result.
            streaming_text_items = [item for item in accumulated_tool if item.startswith("\n📊 **Streaming Text**:")]
            if streaming_text_items:
                # The last streaming entry is the most complete one.
                last_streaming = streaming_text_items[-1]
                final_text = last_streaming.replace("\n📊 **Streaming Text**: ", "").strip()
                if final_text:
                    # Drop all interim entries, keep a single final result.
                    accumulated_tool = [item for item in accumulated_tool if not item.startswith("\n📊 **Streaming Text**:")]
                    accumulated_tool.append(f"\n📊 **Final Result**: {final_text}\n")
            # Handle tool response content.
            tool_content = message_content.content
            # Debug: log the tool content type/size.
            accumulated_tool.append(f"📄 **Tool Content Type**: {type(tool_content).__name__}\n")
            if isinstance(tool_content, str):
                accumulated_tool.append(f"📏 **Content Length**: {len(tool_content)} characters\n")
                if len(tool_content) > 100:
                    accumulated_tool.append(f"📝 **Content Preview**: {tool_content[:100]}...\n")
                else:
                    accumulated_tool.append(f"📝 **Content**: {tool_content}\n")
            else:
                accumulated_tool.append(f"📝 **Content**: {str(tool_content)[:200]}...\n")
            # Handle tool response content.
            if isinstance(tool_content, str):
                # Look for SSE data patterns.
                if "data:" in tool_content:
                    # Parse SSE data line by line and extract meaningful content.
                    lines = tool_content.split('\n')
                    for line in lines:
                        line = line.strip()
                        if line.startswith('data:'):
                            # Count every "data:" message seen.
                            callback_func._data_counter += 1
                            try:
                                # Extract JSON content from the SSE data line.
                                json_str = line[5:].strip()  # Remove 'data:' prefix
                                if json_str:
                                    # Try to parse as JSON.
                                    import json
                                    try:
                                        data_obj = json.loads(json_str)
                                        if isinstance(data_obj, dict):
                                            # Handle different types of SSE data.
                                            if data_obj.get("type") == "result":
                                                content = data_obj.get("content", "")
                                                if content:
                                                    # Check for specific server output formats.
                                                    if "```bdd-long-task-start" in content:
                                                        # Extract task info.
                                                        import re
                                                        match = re.search(r'```bdd-long-task-start\s*\n(.*?)\n```', content, re.DOTALL)
                                                        if match:
                                                            try:
                                                                task_info = json.loads(match.group(1))
                                                                task_id = task_info.get('id', 'Unknown')
                                                                task_label = task_info.get('label', 'Unknown task')
                                                                accumulated_tool.append(f"\n🚀 **Task Started** [{task_id}]: {task_label}\n")
                                                            except:
                                                                accumulated_tool.append(f"\n🚀 **Task Started**: {content}\n")
                                                        # Real-time UI update for task start.
                                                        with tool_placeholder.expander("🔧 Tool Call Information", expanded=True):
                                                            st.markdown("".join(accumulated_tool))
                                                    elif "```bdd-long-task-end" in content:
                                                        # Extract task info.
                                                        import re
                                                        match = re.search(r'```bdd-long-task-end\s*\n(.*?)\n```', content, re.DOTALL)
                                                        if match:
                                                            try:
                                                                task_info = json.loads(match.group(1))
                                                                task_id = task_info.get('id', 'Unknown')
                                                                accumulated_tool.append(f"\n✅ **Task Completed** [{task_id}]\n")
                                                            except:
                                                                accumulated_tool.append(f"\n✅ **Task Completed**: {content}\n")
                                                        # Real-time UI update for task completion.
                                                        with tool_placeholder.expander("🔧 Tool Call Information", expanded=True):
                                                            st.markdown("".join(accumulated_tool))
                                                    elif "```bdd-resource-lookup" in content:
                                                        # Extract resource info.
                                                        import re
                                                        match = re.search(r'```bdd-resource-lookup\s*\n(.*?)\n```', content, re.DOTALL)
                                                        if match:
                                                            try:
                                                                resources = json.loads(match.group(1))
                                                                if isinstance(resources, list):
                                                                    accumulated_tool.append(f"\n📚 **Resources Found**: {len(resources)} items\n")
                                                                    for i, resource in enumerate(resources[:3]):  # Show first 3
                                                                        source = resource.get('source', 'Unknown')
                                                                        doc_id = resource.get('docId', 'Unknown')
                                                                        citation = resource.get('citation', '')
                                                                        accumulated_tool.append(f" - {source}: {doc_id} [citation:{citation}]\n")
                                                                    if len(resources) > 3:
                                                                        accumulated_tool.append(f" ... and {len(resources) - 3} more\n")
                                                            except:
                                                                accumulated_tool.append(f"\n📚 **Resources**: {content}\n")
                                                        # Real-time UI update for resources.
                                                        with tool_placeholder.expander("🔧 Tool Call Information", expanded=True):
                                                            st.markdown("".join(accumulated_tool))
                                                    elif "```bdd-chat-agent-task" in content:
                                                        # Extract chat agent task info.
                                                        import re
                                                        match = re.search(r'```bdd-chat-agent-task\s*\n(.*?)\n```', content, re.DOTALL)
                                                        if match:
                                                            try:
                                                                task_info = json.loads(match.group(1))
                                                                task_type = task_info.get('type', 'Unknown')
                                                                task_label = task_info.get('label', 'Unknown')
                                                                task_status = task_info.get('status', 'Unknown')
                                                                accumulated_tool.append(f"\n🤖 **Agent Task** [{task_status}]: {task_type} - {task_label}\n")
                                                            except:
                                                                accumulated_tool.append(f"\n🤖 **Agent Task**: {content}\n")
                                                    elif "ping - " in content:
                                                        # Extract timestamp from ping messages.
                                                        timestamp = content.split("ping - ")[-1]
                                                        accumulated_tool.append(f"⏱️ **Progress Update**: {timestamp}\n")
                                                    elif data_obj.get("type") == "done":
                                                        # Task completion.
                                                        # NOTE(review): unreachable inside the type=="result"
                                                        # branch — confirm intended chain placement.
                                                        accumulated_tool.append(f"\n🎯 **Task Done**: {content}\n")
                                                    else:
                                                        # Regular result content — accumulate for readability.
                                                        if not hasattr(callback_func, '_result_buffer'):
                                                            callback_func._result_buffer = ""
                                                        callback_func._result_buffer += content
                                                        # Simple text responses (no BDD markers) update immediately.
                                                        is_simple_text = not any(marker in content for marker in ['```bdd-', 'ping -', 'data:'])
                                                        if is_simple_text and content.strip():
                                                            # Replace previous interim entries with the updated one.
                                                            accumulated_tool = [item for item in accumulated_tool if not item.startswith("\n📊 **Streaming Text**:")]
                                                            accumulated_tool.append(f"\n📊 **Streaming Text**: {callback_func._result_buffer}\n")
                                                            # Immediate UI update for text streams.
                                                            with tool_placeholder.expander("🔧 Tool Call Information", expanded=True):
                                                                st.markdown("".join(accumulated_tool))
                                                        else:
                                                            # For complex content, throttle UI updates.
                                                            update_interval = 0.2 if len(content.strip()) <= 10 else 0.5
                                                            if not hasattr(callback_func, '_last_update_time'):
                                                                callback_func._last_update_time = 0
                                                            import time
                                                            current_time = time.time()
                                                            if current_time - callback_func._last_update_time > update_interval:
                                                                # Show the accumulated buffer.
                                                                accumulated_tool.append(f"\n📊 **Result Update**:\n")
                                                                accumulated_tool.append(f"```\n{callback_func._result_buffer}\n```\n")
                                                                callback_func._last_update_time = current_time
                                                                # Real-time UI update.
                                                                with tool_placeholder.expander("🔧 Tool Call Information", expanded=True):
                                                                    st.markdown("".join(accumulated_tool))
                                            else:
                                                # Handle data types other than "result" so that
                                                # ALL "data:" messages are processed and displayed.
                                                data_type = data_obj.get("type", "unknown")
                                                data_content = data_obj.get("content", str(data_obj))
                                                # Add a timestamp for real-time tracking.
                                                import time
                                                timestamp = time.strftime("%H:%M:%S")
                                                # Format the data for display.
                                                if isinstance(data_content, str):
                                                    accumulated_tool.append(f"\n📡 **Data [{data_type}]** [{timestamp}]: {data_content}\n")
                                                else:
                                                    accumulated_tool.append(f"\n📡 **Data [{data_type}]** [{timestamp}]:\n```json\n{json.dumps(data_obj, indent=2)}\n```\n")
                                                # Immediate real-time UI update for any "data:" message.
                                                with tool_placeholder.expander("🔧 Tool Call Information", expanded=True):
                                                    st.markdown("".join(accumulated_tool))
                                        else:
                                            # Handle non-dict data objects.
                                            import time
                                            timestamp = time.strftime("%H:%M:%S")
                                            accumulated_tool.append(f"\n📡 **Raw Data** [{timestamp}]:\n```json\n{json_str}\n```\n")
                                            # Immediate real-time UI update.
                                            with tool_placeholder.expander("🔧 Tool Call Information", expanded=True):
                                                st.markdown("".join(accumulated_tool))
                                    except json.JSONDecodeError:
                                        # Not valid JSON — may be raw streaming text; accumulate it.
                                        if json_str and len(json_str.strip()) > 0:
                                            if not hasattr(callback_func, '_stream_buffer'):
                                                callback_func._stream_buffer = ""
                                            callback_func._stream_buffer += json_str
                                            # Only show streaming content periodically.
                                            if not hasattr(callback_func, '_stream_update_time'):
                                                callback_func._stream_update_time = 0
                                            import time
                                            current_time = time.time()
                                            if current_time - callback_func._stream_update_time > 0.3:  # Update every 0.3 seconds for responsiveness
                                                # Add a new streaming update without clearing previous ones.
                                                if callback_func._stream_buffer.strip():
                                                    accumulated_tool.append(f"\n📝 **Streaming Update**: {callback_func._stream_buffer}\n")
                                                    callback_func._stream_update_time = current_time
                                                    # Real-time UI update.
                                                    with tool_placeholder.expander("🔧 Tool Call Information", expanded=True):
                                                        st.markdown("".join(accumulated_tool))
                            except Exception:
                                # Fallback: treat as plain text, but only if meaningful.
                                if line.strip() and len(line.strip()) > 1:  # Only show non-trivial content
                                    accumulated_tool.append(f"\n📝 **Info**: {line.strip()}\n")
                        elif line.startswith('ping - '):
                            # Handle ping messages directly.
                            timestamp = line.split('ping - ')[-1]
                            accumulated_tool.append(f"⏱️ **Progress Update**: {timestamp}\n")
                            # Immediate real-time UI update for ping messages.
                            with tool_placeholder.expander("🔧 Tool Call Information", expanded=True):
                                st.markdown("".join(accumulated_tool))
                        elif line and not line.startswith(':'):
                            # Other non-empty, non-comment SSE lines.
                            accumulated_tool.append(f"\n📝 **Info**: {line.strip()}\n")
                else:
                    # Regular (non-SSE) tool response content.
                    accumulated_tool.append(
                        "\n```json\n" + str(tool_content) + "\n```\n"
                    )
                    # Capture tool result for later display.
                    if hasattr(callback_func, '_pending_tool_completion') and callback_func._pending_tool_completion:
                        # Get the last completed tool name.
                        last_tool_name = callback_func._pending_tool_completion[-1] if callback_func._pending_tool_completion else "Unknown Tool"
                        # Store the tool result.
                        if not hasattr(callback_func, '_tool_results'):
                            callback_func._tool_results = {}
                        callback_func._tool_results[last_tool_name] = tool_content
                        # Create tool result for display.
                        callback_func._last_tool_result = {
                            'name': last_tool_name,
                            'output': tool_content
                        }
            else:
                # Non-string tool content.
                accumulated_tool.append(
                    "\n```json\n" + str(tool_content) + "\n```\n"
                )
                # Capture tool result for non-string content too.
                if hasattr(callback_func, '_pending_tool_completion') and callback_func._pending_tool_completion:
                    last_tool_name = callback_func._pending_tool_completion[-1] if callback_func._pending_tool_completion else "Unknown Tool"
                    if not hasattr(callback_func, '_tool_results'):
                        callback_func._tool_results = {}
                    callback_func._tool_results[last_tool_name] = tool_content
                    callback_func._last_tool_result = {
                        'name': last_tool_name,
                        'output': tool_content
                    }
        # Show pending tool completion status after all streaming content.
        if hasattr(callback_func, '_pending_tool_completion') and callback_func._pending_tool_completion:
            for tool_name in callback_func._pending_tool_completion:
                accumulated_tool.append(f"\n✅ **Tool Completed**: {tool_name}\n")
                # Check if we have a result for this tool.
                if hasattr(callback_func, '_tool_results') and tool_name in callback_func._tool_results:
                    tool_result = callback_func._tool_results[tool_name]
                    callback_func._last_tool_result = {
                        'name': tool_name,
                        'output': tool_result
                    }
                    accumulated_tool.append(f"📊 **Tool Output Captured**: {len(str(tool_result))} characters\n")
                else:
                    accumulated_tool.append(f"⚠️ **No Tool Output Captured** for {tool_name}\n")
                    # Fall back to a basic result structure.
                    callback_func._last_tool_result = {
                        'name': tool_name,
                        'output': f"Tool {tool_name} completed but output was not captured"
                    }
            # Clear the pending list.
            callback_func._pending_tool_completion = []
        # Enhanced tool result display for MCP tools.
        if hasattr(callback_func, '_last_tool_result') and callback_func._last_tool_result:
            tool_result = callback_func._last_tool_result
            if isinstance(tool_result, dict):
                # Extract tool name and result.
                tool_name = tool_result.get('name', 'Unknown Tool')
                tool_output = tool_result.get('output', tool_result.get('result', tool_result.get('content', str(tool_result))))
                accumulated_tool.append(f"\n🔧 **Tool Result - {tool_name}**:\n")
                if isinstance(tool_output, str) and tool_output.strip():
                    # Format the output nicely; truncate long text.
                    if len(tool_output) > 200:
                        accumulated_tool.append(f"```\n{tool_output[:200]}...\n```\n")
                        accumulated_tool.append(f"📏 *Output truncated. Full length: {len(tool_output)} characters*\n")
                    else:
                        accumulated_tool.append(f"```\n{tool_output}\n```\n")
                else:
                    accumulated_tool.append(f"```json\n{tool_output}\n```\n")
            else:
                accumulated_tool.append(f"\n🔧 **Tool Result**:\n```\n{str(tool_result)}\n```\n")
            # Clear the tool result after displaying.
            callback_func._last_tool_result = None
    # Return the callback function and the accumulating lists.
    return callback_func, accumulated_text, accumulated_tool
async def process_query(query, text_placeholder, tool_placeholder, timeout_seconds=60):
    """
    Processes user questions and generates responses.

    This function passes the user's question to the agent and streams the
    response in real-time into the given Streamlit placeholders. Returns a
    timeout error if the response is not completed within the specified time.

    Args:
        query: Text of the question entered by the user
        text_placeholder: Streamlit component to display text responses
        tool_placeholder: Streamlit component to display tool call information
        timeout_seconds: Response generation time limit (seconds)

    Returns:
        response: Agent's response object, or a ``{"error": msg}`` dict on failure
        final_text: Final accumulated text response
        final_tool: Final accumulated tool call information
    """
    try:
        if st.session_state.agent:
            # The callback appends streamed chunks into the two accumulator
            # lists, which are joined into the final strings at the end.
            streaming_callback, accumulated_text_obj, accumulated_tool_obj = (
                get_streaming_callback(text_placeholder, tool_placeholder)
            )
            try:
                response = await asyncio.wait_for(
                    astream_graph(
                        st.session_state.agent,
                        {"messages": [HumanMessage(content=query)]},
                        callback=streaming_callback,
                        config=RunnableConfig(
                            recursion_limit=st.session_state.recursion_limit,
                            thread_id=st.session_state.thread_id,
                        ),
                    ),
                    timeout=timeout_seconds,
                )
            except asyncio.TimeoutError:
                # On timeout, reset thread to avoid leaving an incomplete tool call in memory
                st.session_state.thread_id = random_uuid()
                error_msg = (
                    f"⏱️ Request time exceeded {timeout_seconds} seconds. Conversation was reset. Please retry."
                )
                return {"error": error_msg}, error_msg, ""
            except ValueError as e:
                # Handle invalid chat history caused by incomplete tool calls.
                # LangGraph raises a ValueError whose message contains this
                # marker string when the checkpointed history is inconsistent.
                if "Found AIMessages with tool_calls" in str(e):
                    # Reset thread and retry once on a fresh thread_id.
                    # NOTE(review): the same streaming_callback (and its
                    # accumulator lists) is reused for the retry, so any partial
                    # output from the failed attempt stays in the buffers —
                    # confirm this mixing is acceptable.
                    st.session_state.thread_id = random_uuid()
                    try:
                        response = await asyncio.wait_for(
                            astream_graph(
                                st.session_state.agent,
                                {"messages": [HumanMessage(content=query)]},
                                callback=streaming_callback,
                                config=RunnableConfig(
                                    recursion_limit=st.session_state.recursion_limit,
                                    thread_id=st.session_state.thread_id,
                                ),
                            ),
                            timeout=timeout_seconds,
                        )
                    except Exception:
                        error_msg = (
                            "⚠️ Conversation state was invalid and has been reset. Please try again."
                        )
                        return {"error": error_msg}, error_msg, ""
                else:
                    # Any other ValueError is unexpected — propagate to the
                    # outer handler below.
                    raise

            final_text = "".join(accumulated_text_obj)
            final_tool = "".join(accumulated_tool_obj)
            return response, final_text, final_tool
        else:
            return (
                {"error": "🚫 Agent has not been initialized."},
                "🚫 Agent has not been initialized.",
                "",
            )
    except Exception as e:
        import traceback

        error_msg = f"❌ Error occurred during query processing: {str(e)}\n{traceback.format_exc()}"
        return {"error": error_msg}, error_msg, ""
async def initialize_session(mcp_config=None):
    """
    Initializes MCP session and agent.

    Validates the MCP server configuration, connects a MultiServerMCPClient,
    loads and filters tools, builds the chat model for the selected provider,
    and creates the ReAct agent, storing everything in ``st.session_state``.

    Args:
        mcp_config: MCP tool configuration information (JSON). Uses default settings if None

    Returns:
        bool: Initialization success status
    """
    with st.spinner("🔄 Connecting to MCP server..."):
        # First safely clean up existing client
        await cleanup_mcp_client()

        if mcp_config is None:
            # Load settings from config.json file
            mcp_config = load_config_from_json()

        # Auto-start the Retrieve service if it is present in the config
        if "bio-qa-chat" in mcp_config:
            st.info("🚀 检测到 bio-qa-chat 服务,正在启动...")
            if start_retrieve_service():
                st.success("✅ Retrieve 服务启动成功")
            else:
                # Best-effort: continue initializing the other servers even
                # if the Retrieve service failed to start.
                st.warning("⚠️ Retrieve 服务启动失败,但继续初始化其他服务")

        # Validate MCP configuration before connecting
        st.info("🔍 Validating MCP server configurations...")
        config_errors = []
        for server_name, server_config in mcp_config.items():
            st.write(f"📋 Checking {server_name}...")
            # Check required fields: transport must exist and be a known value
            if "transport" not in server_config:
                config_errors.append(f"{server_name}: Missing 'transport' field")
                st.error(f"❌ {server_name}: Missing 'transport' field")
            elif server_config["transport"] not in ["stdio", "sse", "http", "streamable_http", "websocket"]:
                config_errors.append(f"{server_name}: Invalid transport '{server_config['transport']}'")
                st.error(f"❌ {server_name}: Invalid transport '{server_config['transport']}'")
            # URL-based servers must not use stdio; stdio servers need
            # 'command' and 'args'.
            if "url" in server_config:
                if "transport" in server_config and server_config["transport"] == "stdio":
                    config_errors.append(f"{server_name}: Cannot use 'stdio' transport with URL")
                    st.error(f"❌ {server_name}: Cannot use 'stdio' transport with URL")
            elif "command" not in server_config:
                config_errors.append(f"{server_name}: Missing 'command' field for stdio transport")
                st.error(f"❌ {server_name}: Missing 'command' field for stdio transport")
            elif "args" not in server_config:
                config_errors.append(f"{server_name}: Missing 'args' field for stdio transport")
                st.error(f"❌ {server_name}: Missing 'args' field for stdio transport")

        if config_errors:
            st.error("🚫 Configuration validation failed!")
            st.error("Please fix the following issues:")
            for error in config_errors:
                st.error(f" • {error}")
            return False

        st.success("✅ MCP configuration validation passed!")

        client = MultiServerMCPClient(mcp_config)

        # Get tools with error handling for malformed schemas
        try:
            tools = await client.get_tools()
            st.session_state.tool_count = len(tools)
            st.success(f"✅ Successfully loaded {len(tools)} tools from all MCP servers")
        except Exception as e:
            st.error(f"❌ Error loading MCP tools: {str(e)}")
            st.error(f"🔍 Error type: {type(e).__name__}")
            st.error(f"📋 Full error details: {repr(e)}")
            st.warning("🔄 Attempting to load tools individually to identify problematic servers...")

            # Fallback: try to load tools from each server individually so
            # one broken server doesn't take down the whole session.
            tools = []
            failed_servers = []
            for server_name, server_config in mcp_config.items():
                try:
                    st.info(f"🔄 Testing connection to {server_name}...")
                    st.json(server_config)  # Show server configuration
                    # Create a single-server client to test in isolation
                    single_client = MultiServerMCPClient({server_name: server_config})
                    server_tools = await single_client.get_tools()
                    tools.extend(server_tools)
                    st.success(f"✅ Loaded {len(server_tools)} tools from {server_name}")
                except Exception as server_error:
                    error_msg = f"❌ Failed to load tools from {server_name}"
                    st.error(error_msg)
                    st.error(f" Error: {str(server_error)}")
                    st.error(f" Type: {type(server_error).__name__}")
                    st.error(f" Details: {repr(server_error)}")
                    failed_servers.append(server_name)
                    continue

            # Summary of per-server results
            if failed_servers:
                st.error(f"🚫 Failed servers: {', '.join(failed_servers)}")
                st.error("💡 Check server configurations and ensure servers are running")
            if not tools:
                st.error("❌ No tools could be loaded from any MCP server. Please check your server configurations.")
                st.error("🔧 Troubleshooting tips:")
                st.error(" 1. Ensure all MCP servers are running")
                st.error(" 2. Check network connectivity and ports")
                st.error(" 3. Verify server configurations in config.json")
                st.error(" 4. Check server logs for errors")
                return False
            else:
                st.success(f"✅ Successfully loaded {len(tools)} tools from working servers")
                st.warning(f"⚠️ Some servers failed: {', '.join(failed_servers)}" if failed_servers else "✅ All servers loaded successfully")

        st.session_state.mcp_client = client

        # Validate and filter tools to remove malformed schemas
        def validate_tool(tool):
            try:
                # Accessing the schema attribute triggers schema validation
                if hasattr(tool, 'schema'):
                    _ = tool.schema
                # Additional validation: check if tool can be converted to
                # OpenAI function-calling format. This catches the FileData
                # unresolved-reference issue seen with some MCP servers.
                try:
                    from langchain_core.utils.function_calling import convert_to_openai_tool
                    _ = convert_to_openai_tool(tool)
                    return True
                except Exception as schema_error:
                    if "FileData" in str(schema_error) or "Reference" in str(schema_error):
                        st.warning(f"⚠️ Tool '{getattr(tool, 'name', 'unknown')}' has malformed schema: {str(schema_error)}")
                        return False
                    # NOTE(review): other conversion errors fall through and
                    # the function implicitly returns None (falsy), silently
                    # filtering the tool with no warning — confirm intended.
            except Exception as e:
                st.warning(f"⚠️ Tool '{getattr(tool, 'name', 'unknown')}' validation failed: {str(e)}")
                return False

        # Filter out invalid tools; anything falsy from validate_tool is dropped
        valid_tools = [tool for tool in tools if validate_tool(tool)]
        if len(valid_tools) < len(tools):
            st.warning(f"⚠️ Filtered out {len(tools) - len(valid_tools)} tools with malformed schemas")
        tools = valid_tools
        st.session_state.tool_count = len(tools)

        # Ensure we have at least some valid tools
        if not tools:
            st.error("❌ No valid tools could be loaded. Please check your MCP server configurations.")
            return False

        # Initialize appropriate model based on selection.
        # OUTPUT_TOKEN_INFO is defined elsewhere in this file and maps model
        # names to their max_tokens limits.
        selected_model = st.session_state.selected_model
        if selected_model in [
            "claude-3-5-sonnet-20241022",
            "claude-3-5-sonnet-latest",
            "claude-3-5-haiku-latest",
        ]:
            model = ChatAnthropic(
                model=selected_model,
                temperature=0.1,
                max_tokens=OUTPUT_TOKEN_INFO[selected_model]["max_tokens"],
            )
        else:  # Use OpenAI(-compatible) model; base_url allows custom endpoints
            model = ChatOpenAI(
                base_url=os.environ.get("OPENAI_API_BASE"),
                model=selected_model,
                temperature=0.1,
                max_tokens=OUTPUT_TOKEN_INFO[selected_model]["max_tokens"],
            )

        # Create agent with error handling
        try:
            agent = create_react_agent(
                model,
                tools,
                checkpointer=MemorySaver(),
                prompt=SYSTEM_PROMPT,
            )
        except Exception as agent_error:
            st.error(f"❌ Failed to create agent: {str(agent_error)}")
            st.warning("🔄 Attempting to create agent with individual tool validation...")

            # Fallback: probe tools one by one by building a throwaway agent
            # for each, keeping only the ones that don't break agent creation.
            working_tools = []
            for i, tool in enumerate(tools):
                try:
                    test_agent = create_react_agent(
                        model,
                        [tool],
                        checkpointer=MemorySaver(),
                        prompt=SYSTEM_PROMPT,
                    )
                    working_tools.append(tool)
                    st.success(f"✅ Tool {i+1} validated successfully")
                except Exception as tool_error:
                    st.error(f"❌ Tool {i+1} failed validation: {str(tool_error)}")
                    continue

            if not working_tools:
                st.error("❌ No tools could be used to create the agent. Please check your MCP server configurations.")
                return False

            # Create agent with only working tools
            tools = working_tools
            st.session_state.tool_count = len(tools)
            agent = create_react_agent(
                model,
                tools,
                checkpointer=MemorySaver(),
                prompt=SYSTEM_PROMPT,
            )
            st.success(f"✅ Agent created successfully with {len(tools)} working tools")

        st.session_state.agent = agent
        st.session_state.session_initialized = True
        return True
# --- Sidebar: System Settings Section ---
with st.sidebar:
    st.subheader("⚙️ System Settings")

    # Model selection feature: build the list of selectable models based on
    # which provider API keys are present in the environment.
    available_models = []

    # Anthropic models are offered only when ANTHROPIC_API_KEY is set
    has_anthropic_key = os.environ.get("ANTHROPIC_API_KEY") is not None
    if has_anthropic_key:
        available_models.extend(
            [
                "claude-3-5-sonnet-20241022",
                "claude-3-5-sonnet-latest",
                "claude-3-5-haiku-latest",
            ]
        )

    # OpenAI models are offered only when OPENAI_API_KEY is set
    has_openai_key = os.environ.get("OPENAI_API_KEY") is not None
    if has_openai_key:
        available_models.extend(["gpt-4o", "gpt-4o-mini"])

    # Display message if no models are available
    if not available_models:
        st.warning(
            "⚠️ API keys are not configured. Please add ANTHROPIC_API_KEY or OPENAI_API_KEY to your .env file."
        )
        # Add Claude model as default (to show UI even without keys)
        available_models = ["claude-3-5-sonnet-20241022"]

    # Model selection dropdown; remember the previous choice so we can warn
    # below when a change requires session re-initialization.
    previous_model = st.session_state.selected_model
    st.session_state.selected_model = st.selectbox(
        "🤖 Select model to use",
        options=available_models,
        index=(
            available_models.index(st.session_state.selected_model)
            if st.session_state.selected_model in available_models
            else 0
        ),
        help="Anthropic models require ANTHROPIC_API_KEY and OpenAI models require OPENAI_API_KEY to be set as environment variables.",
    )

    # Notify when model is changed and session needs to be reinitialized
    if (
        previous_model != st.session_state.selected_model
        and st.session_state.session_initialized
    ):
        st.warning(
            "⚠️ Model has been changed. Click 'Apply Settings' button to apply changes."
        )

    # Timeout setting slider: upper bound for agent response generation,
    # passed to process_query() as timeout_seconds.
    st.session_state.timeout_seconds = st.slider(
        "⏱️ Response generation time limit (seconds)",
        min_value=60,
        max_value=300000,
        value=st.session_state.timeout_seconds,
        step=10,
        help="Set the maximum time for the agent to generate a response. Complex tasks may require more time.",
    )

    # Recursion limit slider: forwarded to the LangGraph RunnableConfig.
    st.session_state.recursion_limit = st.slider(
        "⏱️ Recursion call limit (count)",
        min_value=10,
        max_value=200,
        value=st.session_state.recursion_limit,
        step=10,
        help="Set the recursion call limit. Setting too high a value may cause memory issues.",
    )

    st.divider()  # Add divider

    # Tool settings section
    st.subheader("🔧 Tool Settings")

    # Manage expander open/closed state in session state so it survives reruns
    if "mcp_tools_expander" not in st.session_state:
        st.session_state.mcp_tools_expander = False

    # MCP tool addition interface
    with st.expander("🧰 Add MCP Tools", expanded=st.session_state.mcp_tools_expander):
        # Load settings from config.json file
        loaded_config = load_config_from_json()
        default_config_text = json.dumps(loaded_config, indent=2, ensure_ascii=False)

        # Create pending config based on existing mcp_config_text if not present.
        # Edits accumulate in pending_mcp_config until 'Apply Settings' is clicked.
        if "pending_mcp_config" not in st.session_state:
            try:
                st.session_state.pending_mcp_config = loaded_config
            except Exception as e:
                st.error(f"Failed to set initial pending config: {e}")

        # UI for adding individual tools
        st.subheader("Add Tool(JSON format)")
        st.markdown(
            """
        Please insert **ONE tool** in JSON format.

        [How to Set Up?](https://teddylee777.notion.site/MCP-Tool-Setup-Guide-English-1d324f35d1298030a831dfb56045906a)

        ⚠️ **Important**: JSON must be wrapped in curly braces (`{}`).
        """
        )

        # Provide clearer example
        example_json = {
            "github": {
                "command": "npx",
                "args": [
                    "-y",
                    "@smithery/cli@latest",
                    "run",
                    "@smithery-ai/github",
                    "--config",
                    '{"githubPersonalAccessToken":"your_token_here"}',
                ],
                "transport": "stdio",
            }
        }

        default_text = json.dumps(example_json, indent=2, ensure_ascii=False)

        new_tool_json = st.text_area(
            "Tool JSON",
            default_text,
            height=250,
        )

        # Add button
        if st.button(
            "Add Tool",
            type="primary",
            key="add_tool_button",
            use_container_width=True,
        ):
            try:
                # Validate input: must at least look like a JSON object
                if not new_tool_json.strip().startswith(
                    "{"
                ) or not new_tool_json.strip().endswith("}"):
                    st.error("JSON must start and end with curly braces ({}).")
                    st.markdown('Correct format: `{ "tool_name": { ... } }`')
                else:
                    # Parse JSON
                    parsed_tool = json.loads(new_tool_json)

                    # Check if it's in mcpServers format and process accordingly
                    if "mcpServers" in parsed_tool:
                        # Move contents of mcpServers to top level
                        parsed_tool = parsed_tool["mcpServers"]
                        st.info(
                            "'mcpServers' format detected. Converting automatically."
                        )

                    # Check number of tools entered
                    if len(parsed_tool) == 0:
                        st.error("Please enter at least one tool.")
                    else:
                        # Process all tools
                        success_tools = []
                        for tool_name, tool_config in parsed_tool.items():
                            # Check URL field and normalize the transport value
                            if "url" in tool_config:
                                # Set transport to "streamable_http" if URL exists (preferred) or fallback to "sse"
                                if "transport" not in tool_config:
                                    tool_config["transport"] = "streamable_http"
                                    st.info(
                                        f"URL detected in '{tool_name}' tool, setting transport to 'streamable_http' (recommended)."
                                    )
                                elif tool_config["transport"] == "sse":
                                    st.info(
                                        f"'{tool_name}' tool using SSE transport (deprecated but still supported)."
                                    )
                                elif tool_config["transport"] == "streamable_http":
                                    st.success(
                                        f"'{tool_name}' tool using Streamable HTTP transport (recommended)."
                                    )
                                elif tool_config["transport"] == "http":
                                    st.warning(
                                        f"'{tool_name}' tool using HTTP transport (updating to 'streamable_http' for better compatibility)."
                                    )
                                    tool_config["transport"] = "streamable_http"
                                elif tool_config["transport"] == "websocket":
                                    st.info(
                                        f"'{tool_name}' tool using WebSocket transport."
                                    )
                            elif "transport" not in tool_config:
                                # Set default "stdio" if URL doesn't exist and transport isn't specified
                                tool_config["transport"] = "stdio"

                            # Check required fields
                            if (
                                "command" not in tool_config
                                and "url" not in tool_config
                            ):
                                st.error(
                                    f"'{tool_name}' tool configuration requires either 'command' or 'url' field."
                                )
                            elif "command" in tool_config and "args" not in tool_config:
                                st.error(
                                    f"'{tool_name}' tool configuration requires 'args' field."
                                )
                            elif "command" in tool_config and not isinstance(
                                tool_config["args"], list
                            ):
                                st.error(
                                    f"'args' field in '{tool_name}' tool must be an array ([]) format."
                                )
                            else:
                                # Add tool to pending_mcp_config (takes effect
                                # only after 'Apply Settings')
                                st.session_state.pending_mcp_config[tool_name] = (
                                    tool_config
                                )
                                success_tools.append(tool_name)

                        # Success message
                        if success_tools:
                            if len(success_tools) == 1:
                                st.success(
                                    f"{success_tools[0]} tool has been added. Click 'Apply Settings' button to apply."
                                )
                            else:
                                tool_names = ", ".join(success_tools)
                                st.success(
                                    f"Total {len(success_tools)} tools ({tool_names}) have been added. Click 'Apply Settings' button to apply."
                                )
                            # Collapse expander after adding
                            st.session_state.mcp_tools_expander = False
                            st.rerun()
            except json.JSONDecodeError as e:
                st.error(f"JSON parsing error: {e}")
                st.markdown(
                    f"""
                **How to fix**:
                1. Check that your JSON format is correct.
                2. All keys must be wrapped in double quotes (").
                3. String values must also be wrapped in double quotes (").
                4. When using double quotes within a string, they must be escaped (\\").
                """
                )
            except Exception as e:
                st.error(f"Error occurred: {e}")

    # Display registered tools list and add delete buttons
    with st.expander("📋 Registered Tools List", expanded=True):
        try:
            pending_config = st.session_state.pending_mcp_config
        except Exception as e:
            st.error("Not a valid MCP tool configuration.")
        else:
            # Iterate through keys (tool names) in pending config; list() makes
            # a snapshot so deletion during iteration is safe.
            for tool_name in list(pending_config.keys()):
                col1, col2 = st.columns([8, 2])
                col1.markdown(f"- **{tool_name}**")
                if col2.button("Delete", key=f"delete_{tool_name}"):
                    # Delete tool from pending config (not applied immediately)
                    del st.session_state.pending_mcp_config[tool_name]
                    st.success(
                        f"{tool_name} tool has been deleted. Click 'Apply Settings' button to apply."
                    )

    st.divider()  # Add divider
# --- Sidebar: System Information and Action Buttons Section ---
with st.sidebar:
    st.subheader("📊 System Information")
    st.write(
        f"🛠️ MCP Tools Count: {st.session_state.get('tool_count', 'Initializing...')}"
    )
    selected_model_name = st.session_state.selected_model
    st.write(f"🧠 Current Model: {selected_model_name}")

    # 'Apply Settings' button: persists pending_mcp_config and reinitializes
    # the MCP session/agent with the new configuration.
    if st.button(
        "Apply Settings",
        key="apply_button",
        type="primary",
        use_container_width=True,
    ):
        # Display applying message
        apply_status = st.empty()
        with apply_status.container():
            st.warning("🔄 Applying changes. Please wait...")
            progress_bar = st.progress(0)

            # Save settings
            st.session_state.mcp_config_text = json.dumps(
                st.session_state.pending_mcp_config, indent=2, ensure_ascii=False
            )

            # Save settings to config.json file.
            # NOTE(review): on save failure an error is shown but initialization
            # proceeds anyway — confirm this best-effort behavior is intended.
            save_result = save_config_to_json(st.session_state.pending_mcp_config)
            if not save_result:
                st.error("❌ Failed to save settings file.")

            progress_bar.progress(15)

            # Prepare session initialization: drop the old agent first
            st.session_state.session_initialized = False
            st.session_state.agent = None

            # Update progress
            progress_bar.progress(30)

            # Run initialization on the shared global event loop
            success = st.session_state.event_loop.run_until_complete(
                initialize_session(st.session_state.pending_mcp_config)
            )

            # Update progress
            progress_bar.progress(100)

            if success:
                st.success("✅ New settings have been applied.")
                # Collapse tool addition expander
                if "mcp_tools_expander" in st.session_state:
                    st.session_state.mcp_tools_expander = False
            else:
                st.error("❌ Failed to apply settings.")

        # Refresh page
        st.rerun()

    st.divider()  # Add divider

    # Action buttons section
    st.subheader("🔄 Actions")

    # Retrieve service control buttons (start/stop the background process)
    st.subheader("🔧 Retrieve 服务控制")
    col1, col2 = st.columns(2)

    with col1:
        if st.button("🚀 启动服务", use_container_width=True, type="primary"):
            if start_retrieve_service():
                st.success("✅ 服务启动成功")
            else:
                st.error("❌ 服务启动失败")
            st.rerun()

    with col2:
        if st.button("🛑 停止服务", use_container_width=True, type="secondary"):
            stop_retrieve_service()
            st.rerun()

    # Show current Retrieve service status
    if st.session_state.get("retrieve_started", False):
        st.success("🟢 Retrieve 服务运行中")
    else:
        st.warning("🔴 Retrieve 服务未运行")

    st.divider()

    # Reset conversation button: new thread_id starts a fresh LangGraph thread
    if st.button("Reset Conversation", use_container_width=True, type="primary"):
        # Reset thread_id
        st.session_state.thread_id = random_uuid()

        # Reset conversation history
        st.session_state.history = []

        # Notification message
        st.success("✅ Conversation has been reset.")

        # Refresh page
        st.rerun()

    # Show logout button only if login feature is enabled
    if use_login and st.session_state.authenticated:
        st.divider()  # Add divider
        if st.button("Logout", use_container_width=True, type="secondary"):
            st.session_state.authenticated = False
            st.success("✅ You have been logged out.")
            st.rerun()
# --- Initialize default session (if not initialized) ---
if not st.session_state.session_initialized:
    st.info(
        "MCP server and agent are not initialized. Please click the 'Apply Settings' button in the left sidebar to initialize."
    )

# --- Print conversation history ---
print_message()

# --- User input and processing ---
user_query = st.chat_input("💬 Enter your question")
if user_query:
    if st.session_state.session_initialized:
        st.chat_message("user", avatar="🧑‍💻").markdown(user_query)
        with st.chat_message("assistant", avatar="🤖"):
            # Tool output streams above the text answer in the assistant bubble
            tool_placeholder = st.empty()
            text_placeholder = st.empty()
            # Run the async query on the shared global event loop
            resp, final_text, final_tool = (
                st.session_state.event_loop.run_until_complete(
                    process_query(
                        user_query,
                        text_placeholder,
                        tool_placeholder,
                        st.session_state.timeout_seconds,
                    )
                )
            )
        # NOTE(review): process_query returns a {"error": ...} dict on failure;
        # on success resp is the agent response object — confirm membership
        # testing with `in` is supported by that object type.
        if "error" in resp:
            st.error(resp["error"])
        else:
            # Persist the exchange so print_message() can re-render it on rerun
            st.session_state.history.append({"role": "user", "content": user_query})
            st.session_state.history.append(
                {"role": "assistant", "content": final_text}
            )
            if final_tool.strip():
                st.session_state.history.append(
                    {"role": "assistant_tool", "content": final_tool}
                )
            st.rerun()
    else:
        st.warning(
            "⚠️ MCP server and agent are not initialized. Please click the 'Apply Settings' button in the left sidebar to initialize."
        )
# Cleanup logic invoked when the application exits
def cleanup_on_exit():
    """Release resources on application exit.

    Best-effort: stops the background Retrieve service process if one was
    started during this session. Failures are deliberately ignored because
    this runs during interpreter shutdown, where raising is pointless.
    """
    try:
        if "retrieve_process" in st.session_state and st.session_state.retrieve_process:
            stop_retrieve_service()
    except Exception:
        # Narrowed from a bare `except:` so that SystemExit/KeyboardInterrupt
        # raised during shutdown are not silently swallowed.
        pass
# Register the cleanup function to run at interpreter exit
import atexit
atexit.register(cleanup_on_exit)

# NOTE: Signal handlers cannot be used here because Streamlit executes the
# script in a worker thread; cleanup is handled via atexit and session-state
# checks on page refresh instead.