import json
import logging
import os
import os.path
import sys
from typing import Tuple, Any

from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from langchain.globals import set_debug
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.chat_message_histories import FileChatMessageHistory
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage, SystemMessage
from langchain_mcp_adapters.tools import load_mcp_tools
from langgraph.prebuilt import create_react_agent
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

from memory_store import MemoryStore

load_dotenv()
# set_debug(True)

# Set up logging
logger = logging.getLogger(__name__)

# User request that wipes the shared chat memory instead of querying the agent.
CLEAR_CACHE_COMMAND = "/clear-cache"


async def lc_mcp_exec(request: str, history=None) -> Tuple[str, list]:
    """
    Execute the PostgreSQL MCP pipeline with in-memory chat history.

    Args:
        request: The user's message. The special value ``"/clear-cache"``
            clears the shared chat memory instead of querying the agent.
        history: Unused; retained for backward compatibility with existing
            callers. Conversation state comes from ``MemoryStore.get_memory()``.

    Returns:
        A ``(response, messages)`` tuple: the agent's final reply text and the
        updated message history. ``("Memory cleared", [])`` for the clear
        command, and ``("Error: <message>", [])`` if anything fails.
    """
    try:
        # Get the singleton memory store instance
        message_history = MemoryStore.get_memory()

        # Handle the clear command before any expensive setup. There is no need
        # to read files, build the LLM, or spawn the MCP server subprocess just
        # to wipe memory, and a config problem should not block clearing it.
        if request == CLEAR_CACHE_COMMAND:
            message_history.clear()
            return "Memory cleared", []

        # Load table summary and server parameters
        table_summary = load_table_summary(os.environ["TABLE_SUMMARY_PATH"])
        server_params = get_server_params()
        llm = _init_llm()

        # Start the MCP server subprocess and open a client session over stdio
        async with stdio_client(server_params) as (read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()

                # Load tools and create the agent
                tools = await load_and_enrich_tools(session)
                agent = create_react_agent(llm, tools)

                # Build the system prompt BEFORE recording the user turn, so a
                # failure here does not leave an unanswered message in memory.
                system_prompt = await build_prompt(session, tools, table_summary)
                system_message = SystemMessage(content=system_prompt)

                message_history.add_user_message(request)

                # The system message is prepended per call and never stored.
                input_messages = [system_message] + message_history.messages

                agent_response = await agent.ainvoke(
                    {"messages": input_messages},
                    config={"configurable": {"thread_id": "conversation_123"}},
                )

                response_content = _record_agent_response(
                    agent_response, len(input_messages), message_history
                )

        return response_content, message_history.messages

    except Exception as e:
        # Top-level boundary: log with traceback and report the error to the caller.
        logger.exception("Error in execution: %s", e)
        return f"Error: {str(e)}", []


# ---------------- Helper Functions ---------------- #

def _init_llm():
    """Create the chat model: OpenAI if ``OPENAI_API_KEY`` is set, otherwise Gemini.

    Raises:
        KeyError: if a required ``*_MODEL_PROVIDER`` / ``*_MODEL`` variable (or
            ``GEMINI_API_KEY`` for the Gemini path) is missing from the environment.
    """
    openai_api_key = os.environ.get("OPENAI_API_KEY")
    if openai_api_key:
        # Initialize the LLM for OpenAI
        return init_chat_model(
            model_provider=os.environ["OPENAI_MODEL_PROVIDER"],
            model=os.environ["OPENAI_MODEL"],
            api_key=openai_api_key,
        )

    # Initialize the LLM for Gemini
    return init_chat_model(
        model_provider=os.environ["GEMINI_MODEL_PROVIDER"],
        model=os.environ["GEMINI_MODEL"],
        api_key=os.environ["GEMINI_API_KEY"],
    )


def _record_agent_response(agent_response, n_input: int, message_history) -> str:
    """Save the agent's newly produced messages to memory and return the reply text.

    Messages after the first ``n_input`` entries of ``agent_response["messages"]``
    are treated as new (this assumes the agent echoes its inputs first — TODO
    confirm against the langgraph version in use). Only AI and tool messages are
    stored; other types are logged and skipped. If the response has no
    messages, a placeholder reply is stored and returned instead.
    """
    response_content = "No response generated"

    if "messages" in agent_response and agent_response["messages"]:
        all_messages = agent_response["messages"]
        for msg in all_messages[n_input:]:
            if isinstance(msg, (AIMessage, ToolMessage)):
                message_history.add_message(msg)
            else:
                logger.debug("Skipping unexpected message type: %s", type(msg))
        response_content = all_messages[-1].content
    else:
        message_history.add_ai_message(response_content)

    return response_content


def load_table_summary(path: str) -> str:
    """Return the full text of the table-summary file at *path*, decoded as UTF-8."""
    with open(path, "r", encoding="utf-8") as file:
        return file.read()


def get_server_params() -> StdioServerParameters:
    """Build the stdio launch parameters for the MCP server subprocess.

    Only an allow-listed subset of this process's environment is forwarded to
    the server; missing variables are logged as warnings, not raised.

    Raises:
        KeyError: if ``MCP_SERVER_PATH`` is not set.
    """
    # List of environment variables that the postgre_mcp_server.py needs
    required_vars_for_server = [
        # "TABLE_SUMMARY_PATH",
        "DB_URL",
        "DB_SCHEMA",
        "PANDAS_KEY",
        "PANDAS_EXPORTS_PATH",
        # "GEMINI_API_KEY",
        # "GEMINI_MODEL",
        # "GEMINI_MODEL_PROVIDER",
        # "OPENAI_MODEL_PROVIDER",
        # "OPENAI_MODEL",
        "OPENAI_API_KEY",
    ]

    # Prepare the environment dictionary to pass to the subprocess
    subprocess_env = {}
    for var_name in required_vars_for_server:
        value = os.getenv(var_name)
        if value is not None:
            subprocess_env[var_name] = value
        else:
            logger.warning(
                "Environment variable %s not found for passing to MCP server subprocess.",
                var_name,
            )

    logger.info(
        "Passing environment to MCP server subprocess: %s", sorted(subprocess_env)
    )

    return StdioServerParameters(
        # Use this interpreter rather than whatever "python" resolves to: the
        # forwarded env carries no PATH/VIRTUAL_ENV, and the server needs the
        # same installed packages (mcp, etc.) as this client.
        command=sys.executable,
        args=[os.environ["MCP_SERVER_PATH"]],  # MCP_SERVER_PATH itself must be available to this client
        env=subprocess_env,
    )


async def load_and_enrich_tools(session: ClientSession):
    """Load the MCP server's tools as LangChain tools.

    No enrichment is currently applied; the tools are returned exactly as
    ``load_mcp_tools`` produces them.
    """
    tools = await load_mcp_tools(session)
    return tools


async def build_prompt(session, tools, table_summary):
    """Render the server's ``resource://base_prompt`` template into a system prompt.

    The template is filled via ``str.format`` with ``tools`` (one
    ``"- name: description"`` line per tool) and ``descriptions`` (the table
    summary text). Literal braces in the template must therefore be doubled.

    Raises:
        ValueError: if the resource returns no contents.
    """
    conversation_prompt = await session.read_resource("resource://base_prompt")
    if not conversation_prompt.contents:
        raise ValueError("resource://base_prompt returned no contents")
    template = conversation_prompt.contents[0].text

    tools_str = "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)

    return template.format(
        tools=tools_str,
        descriptions=table_summary,
    )