import logging
import os
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import AsyncIterator, Optional

import asyncpg
import matplotlib as plt  # NOTE(review): alias suggests matplotlib.pyplot was intended; unused here
import pandas as pd
import pandasai as pai
from mcp.server.fastmcp import Context, FastMCP
from pandasai_openai import OpenAI
from pydantic import Field

# Maximum number of rows execute_query returns when the caller gives no limit.
DEFAULT_QUERY_LIMIT = 100

logger = logging.getLogger(__name__)


@dataclass
class PromptMessage:
    """Minimal prompt-message container used when the MCP one isn't available.

    Mirrors the shape MCP prompts expect: a text content payload and a role.
    """
    content: str
    role: Optional[str] = "user"


@dataclass
class DbContext:
    """Database state shared with tools through the server lifespan context."""
    pool: asyncpg.Pool  # connection pool created by db_lifespan
    schema: str         # schema name taken from the DB_SCHEMA env var


@asynccontextmanager
async def db_lifespan(server: FastMCP) -> AsyncIterator[DbContext]:
    """Manage the database connection pool for the server's lifetime.

    Reads DB_URL and DB_SCHEMA from the environment (raises KeyError if
    either is missing), yields a DbContext, and always closes the pool on
    exit — even when the server body raises.
    """
    dsn = os.environ["DB_URL"]
    db_schema = os.environ["DB_SCHEMA"]
    pool = await asyncpg.create_pool(
        dsn,
        min_size=1,
        max_size=4,
        max_inactive_connection_lifetime=300,
        timeout=60,
        command_timeout=300,
    )
    try:
        yield DbContext(pool=pool, schema=db_schema)
    finally:
        # Clean up the pool no matter how the lifespan ends.
        await pool.close()


# Create the server with database lifecycle management.
mcp = FastMCP(
    "SQL Database Server",
    dependencies=["asyncpg", "pydantic"],
    lifespan=db_lifespan,
)


def _format_records(records) -> str:
    """Render a list of asyncpg records as a pipe-separated text table.

    Shared by the query tools so all tabular output looks the same.
    """
    if not records:
        return "No rows returned."
    columns = list(records[0].keys())
    header = " | ".join(columns)
    separator = "-" * len(header)
    rows = [" | ".join(str(val) for val in rec.values()) for rec in records]
    return f"{header}\n{separator}\n" + "\n".join(rows)


@mcp.resource(
    uri="resource://base_prompt",
    name="base_prompt",
    description="A base prompt to generate SQL queries and answer questions",
)
async def base_prompt_query() -> str:
    """Returns a base prompt to generate sql queries and answer questions.

    The string contains `{descriptions}` / `{tools}` placeholders that the
    consumer is expected to fill in; doubled braces (`{{...}}`) survive a
    later `.format()` call as literal braces.
    """
    base_prompt = """
==========================
# Your Role
==========================
You are an expert in generating and executing SQL queries, interacting with a PostgreSQL database using **FastMCP tools**, and visualizing results when requested.

These tools allow you to:
- List available tables
- Retrieve schema details
- Execute SQL queries
- Visualize query results using PandasAI

Each tool may return previews or summaries of table contents to help you understand the data structure.

---
==========================
# Your Objective
==========================
When a user submits a request, you must:
1. **Analyze the request** to determine the required data, action, or visualization.
2. **Use FastMCP tools** to gather necessary information (e.g., list tables, retrieve schema).
3. **Generate a valid SQL SELECT query**, if needed, and clearly show the full query.
4. **Execute the SQL query** using the `execute_query` tool and return the results.
5. **Visualize results** using the `visualize_results` tool if the user explicitly requests a visualization (e.g., "create a chart", "visualize", "plot"). For visualizations:
   - Craft a visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
   - Send JSON data in the format `{{'columns': [list of column names], 'data': [list of rows, each row a list of values]}}`.
6. **Chain tools logically**, such as: List Tables → Get Schema → Write and Run Query → Visualize Results (if requested).
7. **Explain your reasoning and each step taken** to ensure clarity and transparency.

---
==========================
# Critical Rules
==========================
- Only use **read-only** SQL queries such as **SELECT**, **COUNT**, or queries with **GROUP BY**, **ORDER BY**, etc.
- **Never** use destructive operations like **DELETE**, **UPDATE**, **INSERT**, or **DROP**.
- Always show the SQL query you generate along with the execution result.
- Validate SQL syntax before execution.
- Never assume table or column names. Use tools to confirm structure.
- Use memory efficiently. Don't rerun a tool unless necessary.
- If you generate a SQL query, immediately call the **execute_query** tool.
- If the user requests a visualization, call the **visualize_results** tool with:
  - A visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
  - JSON data in the format `{{'columns': [list of column names], 'data': [list of rows, each row a list of values]}}`.
- For non-query or non-visualization requests (e.g., history questions), respond appropriately without forcing a query or visualization.

---
==========================
# Database Description
==========================
{descriptions}

---
==========================
# Tools
==========================
You can use the following FastMCP tools to create **read-only** queries (e.g., `SELECT`, `COUNT`, `GROUP BY`, `ORDER BY`) and visualize results when requested. Chain tools logically to gather information, execute queries, or generate visualizations.

{tools}

---
### Invalid Example — DELETE Operation (Not Allowed):
**User Request:** "Delete all customers from Germany."

**Response Guidance:**
- **Do not generate or execute** destructive queries such as `DELETE`.
- Instead, respond with a message like:
> Destructive operations such as `DELETE` are not permitted. I can help you retrieve the customers from Germany using a `SELECT` query instead:
> ```sql
> SELECT * FROM customers WHERE country = 'Germany';
> ```

==========================
# Output Format
==========================
Present your final answer using the following structure in markdown language:

# Result
{{Take the result from the execute_query tool and format it nicely using Markdown. Use a beautiful Markdown table for tabular data (rows and columns) including headers and show such simple results using a table. Use bullet points or items in markdown for answers that include lists of names or descriptions. Use plain text for single values or simple messages. Ensure data alignment and clarity.}}

# Explanation
{{Provide a concise explanation or interpretation of the results (and visualization, if applicable) in 1-3 sentences. Explain what the data and visualization (if any) represent in the context of the user's request.}}

# Query
```sql
{{Display the exact SQL query you generated and executed here to answer the user's request.}}
```

==========================
# Reminder
==========================
- **Every time you generate a SQL query, call `execute_query` immediately and include the result.**
- **If the user requests a visualization (e.g., "create a chart", "visualize", "plot"), call `visualize_results` with:**
  - A visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
  - JSON data in the format `{{'columns': [list of column names], 'data': [list of rows, each row a list of values]}}`.
- **For non-query or non-visualization requests, respond appropriately without forcing a query or visualization.**

**Conversation History:** Use the conversation history for context when available to maintain continuity.
"""
    return base_prompt


@mcp.tool(description="tests the database connection and returns the PostgreSQL version or an error message.")
async def test_connection(ctx: Context) -> str:
    """Test database connection.

    Returns a success string with the server version, or a failure string —
    it never raises, so the tool result is always presentable.
    """
    try:
        pool = ctx.request_context.lifespan_context.pool
        async with pool.acquire() as conn:
            version = await conn.fetchval("SELECT version();")
        return f"Connection successful. PostgreSQL version: {version}"
    except Exception as e:
        return f"Connection failed: {str(e)}"


@mcp.tool(description="Executes a read-only SQL SELECT query and returns formatted results or an error message.")
async def execute_query(
    query: str = Field(description="SQL query to execute (SELECT only)"),
    limit: Optional[int] = Field(default=DEFAULT_QUERY_LIMIT, description="Maximum number of rows to return"),
    ctx: Context = None,
) -> str:
    """Execute a read-only SQL query against the database.

    Args:
        query: SQL text; anything not starting with SELECT is rejected.
            NOTE(review): a prefix check is a weak read-only guard — a
            read-only database role would be a stronger guarantee.
        limit: cap on rows rendered (falls back to DEFAULT_QUERY_LIMIT).
        ctx: FastMCP request context carrying the lifespan pool.

    Returns:
        A pipe-separated text table, a "no rows" notice, or an error string.
    """
    query = query.strip()
    if not query.lower().startswith("select"):
        return "Error: Only SELECT queries are allowed for security reasons."
    try:
        pool = ctx.request_context.lifespan_context.pool
        async with pool.acquire() as conn:
            result = await conn.fetch(query)
        if not result:
            return "Query executed successfully. No rows returned."
        # Format the header from the first record, then the capped rows.
        columns = list(result[0].keys())
        header = " | ".join(columns)
        separator = "-" * len(header)
        rows = [
            " | ".join(str(val) for val in row.values())
            for row in result[: limit if limit else DEFAULT_QUERY_LIMIT]
        ]
        return f"{header}\n{separator}\n" + "\n".join(rows)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"


# Database helper functions

async def get_all_tables(pool, db_schema):
    """Get all ordinary/partitioned tables in a schema.

    Excludes partition children (via pg_inherits) and pg_* system names.
    """
    logger.debug("schema: %s", db_schema)
    async with pool.acquire() as conn:
        result = await conn.fetch(
            """
            SELECT c.relname AS table_name
            FROM pg_class AS c
            JOIN pg_namespace AS n ON n.oid = c.relnamespace
            WHERE NOT EXISTS (
                SELECT 1 FROM pg_inherits AS i WHERE i.inhrelid = c.oid
            )
            AND c.relkind IN ('r', 'p')
            AND n.nspname = $1
            AND c.relname NOT LIKE 'pg_%'
            ORDER BY c.relname;
            """,
            db_schema,
        )
    return result


async def get_table_schema_info(pool, db_schema, table_name):
    """Get column metadata for a specific table from information_schema."""
    async with pool.acquire() as conn:
        columns = await conn.fetch(
            """
            SELECT column_name, data_type, is_nullable, column_default,
                   character_maximum_length
            FROM information_schema.columns
            WHERE table_schema = $1 AND table_name = $2
            ORDER BY ordinal_position;
            """,
            db_schema,
            table_name,
        )
    return columns


def format_table_schema(table_name, columns):
    """Format column records (from get_table_schema_info) into readable text."""
    if not columns:
        return f"Table '{table_name}' not found."
    result = [f"Table: {table_name}", "Columns:"]
    for col in columns:
        nullable = "NULL" if col['is_nullable'] == 'YES' else "NOT NULL"
        length = f"({col['character_maximum_length']})" if col['character_maximum_length'] else ""
        default = f" DEFAULT {col['column_default']}" if col['column_default'] else ""
        result.append(f"- {col['column_name']} ({col['data_type']}{length}) {nullable}{default}")
    return "\n".join(result)


@mcp.tool(description="Lists all table names in the configured database schema.")
async def list_tables() -> str:
    """List all tables in the database, one name per line."""
    try:
        async with db_lifespan(mcp) as db_ctx:
            result = await get_all_tables(db_ctx.pool, db_ctx.schema)
            if not result:
                return f"No tables found in the {db_ctx.schema} schema."
            return "\n".join(row['table_name'] for row in result)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"


@mcp.tool(description="Retrieves and formats the schema details of a specific table in the database.")
async def get_table_schema(table_name: str) -> str:
    """Get schema information for a specific table.

    Args:
        table_name: table to describe (looked up in the DB_SCHEMA schema).
    """
    try:
        db_schema = os.environ["DB_SCHEMA"]
        async with db_lifespan(mcp) as db_ctx:
            columns = await get_table_schema_info(db_ctx.pool, db_schema, table_name)
            if not columns:
                return f"Table '{table_name}' not found in {db_schema} schema."
            return format_table_schema(table_name, columns)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"


@mcp.tool(description="Retrieves foreign key relationships for a specified table.")
async def get_foreign_keys(table_name: str) -> str:
    """Get foreign key information for a table.

    FIX(review): the original built the SQL with unexpanded `{db_schema}` /
    `{table_name}` placeholders (the string was not an f-string, so literal
    braces reached PostgreSQL) and called the async execute_query tool
    without awaiting it and with an unsupported tuple argument. The query
    now uses $1/$2 bind parameters and runs directly on the pool.

    Args:
        table_name: The name of the table to get foreign keys from.
    """
    sql = """
        SELECT
            tc.constraint_name,
            kcu.column_name AS fk_column,
            ccu.table_schema AS referenced_schema,
            ccu.table_name AS referenced_table,
            ccu.column_name AS referenced_column
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
        JOIN information_schema.referential_constraints rc
            ON tc.constraint_name = rc.constraint_name
        JOIN information_schema.constraint_column_usage ccu
            ON rc.unique_constraint_name = ccu.constraint_name
        WHERE tc.constraint_type = 'FOREIGN KEY'
            AND tc.table_schema = $1
            AND tc.table_name = $2
        ORDER BY tc.constraint_name, kcu.ordinal_position
    """
    try:
        async with db_lifespan(mcp) as db_ctx:
            async with db_ctx.pool.acquire() as conn:
                rows = await conn.fetch(sql, db_ctx.schema, table_name)
        if not rows:
            return f"No foreign keys found for table '{table_name}'."
        return _format_records(rows)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"


@mcp.tool(description="Fetches and formats the schema details for all tables in the configured database schema.")
async def get_all_schemas() -> str:
    """Get schema information for all tables in the database.

    Returns each table's formatted schema separated by blank lines.
    """
    try:
        async with db_lifespan(mcp) as db_ctx:
            tables = await get_all_tables(db_ctx.pool, db_ctx.schema)
            if not tables:
                return f"No tables found in the {db_ctx.schema} schema."
            all_schemas = []
            for table in tables:
                table_name = table['table_name']
                columns = await get_table_schema_info(db_ctx.pool, db_ctx.schema, table_name)
                all_schemas.append(format_table_schema(table_name, columns))
                all_schemas.append("")  # blank line between tables
            return "\n".join(all_schemas)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"


@mcp.prompt(description="Generates a prompt message to help craft a best-practice SELECT query for a given table.")
async def generate_select_query(table_name: str) -> list[PromptMessage]:
    """Generate a SELECT query with best practices for a table."""
    try:
        async with db_lifespan(mcp) as db_ctx:
            async with db_ctx.pool.acquire() as conn:
                columns = await conn.fetch(
                    """
                    SELECT column_name, data_type
                    FROM information_schema.columns
                    WHERE table_schema = $1 AND table_name = $2
                    ORDER BY ordinal_position
                    """,
                    db_ctx.schema,
                    table_name,
                )
            if not columns:
                return [PromptMessage(f"Table '{table_name}' not found in schema '{db_ctx.schema}'.")]
            columns_text = "\n".join(
                f"- {col['column_name']} ({col['data_type']})" for col in columns
            )
            return [
                PromptMessage(
                    f"""Please help me write a well-structured, efficient SELECT query for the '{table_name}' table.

Table Schema:
{columns_text}

PostgreSQL SQL Best Practices:
- Use explicit column names instead of * when possible
- Include LIMIT clauses to restrict result sets
- Consider adding WHERE clauses to filter results
- Use appropriate indexing considerations
- Format SQL with proper indentation and line breaks

Create a basic SELECT query following these best practices:"""
                )
            ]
    except Exception as e:
        return [PromptMessage(f"Error generating select query: {str(e)}")]


@mcp.prompt(description="Generates a prompt message to assist in writing analytical queries for a given table.")
async def generate_analytical_query(table_name: str) -> list[PromptMessage]:
    """Generate analytical queries for a table.

    FIX(review): the original interpolated schema/table into the SQL via an
    f-string (unquoted, injection-prone) while *also* passing bind arguments
    for placeholders that did not exist — the query could never run. It now
    uses $1/$2 bind parameters like the sibling generate_select_query.

    Args:
        table_name: The name of the table to generate analytical queries for.
    """
    try:
        async with db_lifespan(mcp) as db_ctx:
            async with db_ctx.pool.acquire() as conn:
                columns = await conn.fetch(
                    """
                    SELECT column_name, data_type
                    FROM information_schema.columns
                    WHERE table_schema = $1 AND table_name = $2
                    ORDER BY ordinal_position
                    """,
                    db_ctx.schema,
                    table_name,
                )
            if not columns:
                return [PromptMessage(f"Table '{table_name}' not found in schema '{db_ctx.schema}'.")]
            columns_text = "\n".join(
                f"- {col['column_name']} ({col['data_type']})" for col in columns
            )
            return [
                PromptMessage(
                    f"""Please help me create analytical queries for the '{table_name}' table.

Table Schema:
{columns_text}

PostgreSQL SQL Best Practices:
- Use aggregation functions (COUNT, SUM, AVG, MIN, MAX) appropriately
- Group data using GROUP BY for meaningful aggregations
- Filter groups with HAVING clauses when needed
- Consider using window functions for advanced analytics
- Format SQL with proper indentation and line breaks

Create a set of analytical queries for this table:"""
                )
            ]
    except Exception as e:
        return [PromptMessage(f"Error generating analytical query: {str(e)}")]


@mcp.tool(
    description="Identifies both explicit and implied foreign key relationships for a given table using schema analysis and naming patterns.")
async def find_relationships(table_name: str, db_schema: str = 'public') -> str:
    """Find both explicit and implied relationships for a table.

    FIX(review): the original f-string-interpolated schema/table names
    unquoted into both queries (guaranteed syntax error, injection-prone)
    and called the async execute_query tool without awaiting it and without
    a context; it also compared against a "No results found" string that
    execute_query never produces. Both queries now use $1/$2 bind
    parameters and run directly on the pool.

    Args:
        table_name: The name of the table to analyze relationships for.
        db_schema: The schema name (defaults to 'public').
    """
    # Explicit foreign key relationships from information_schema.
    fk_sql = """
        SELECT
            kcu.column_name,
            ccu.table_name AS foreign_table,
            ccu.column_name AS foreign_column,
            'Explicit FK' AS relationship_type,
            1 AS confidence_level
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
        JOIN information_schema.constraint_column_usage ccu
            ON ccu.constraint_name = tc.constraint_name
            AND ccu.table_schema = tc.table_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
            AND tc.table_schema = $1
            AND tc.table_name = $2
    """
    # Implied relationships guessed from id/_id/_fk naming patterns, ranked
    # by confidence (2 = strongest implied match ... 5 = weakest).
    implied_sql = """
        WITH source_columns AS (
            -- Get all ID-like columns from our table
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_schema = $1
                AND table_name = $2
                AND (
                    column_name LIKE '%id'
                    OR column_name LIKE '%_id'
                    OR column_name LIKE '%_fk'
                )
        ),
        potential_references AS (
            -- Find tables that might be referenced by our ID columns
            SELECT DISTINCT
                sc.column_name AS source_column,
                sc.data_type AS source_type,
                t.table_name AS target_table,
                c.column_name AS target_column,
                c.data_type AS target_type,
                CASE
                    -- Highest confidence: column matches table_id pattern and types match
                    WHEN sc.column_name = t.table_name || '_id'
                        AND sc.data_type = c.data_type THEN 2
                    -- High confidence: column ends with _id and types match
                    WHEN sc.column_name LIKE '%_id'
                        AND sc.data_type = c.data_type THEN 3
                    -- Medium confidence: column contains table name and types match
                    WHEN sc.column_name LIKE '%' || t.table_name || '%'
                        AND sc.data_type = c.data_type THEN 4
                    -- Lower confidence: column ends with id and types match
                    WHEN sc.column_name LIKE '%id'
                        AND sc.data_type = c.data_type THEN 5
                END AS confidence_level
            FROM source_columns sc
            CROSS JOIN information_schema.tables t
            JOIN information_schema.columns c
                ON c.table_schema = t.table_schema
                AND c.table_name = t.table_name
                AND (c.column_name = 'id' OR c.column_name = sc.column_name)
            WHERE t.table_schema = $1
                AND t.table_name != $2  -- Exclude self-references
        )
        SELECT
            source_column AS column_name,
            target_table AS foreign_table,
            target_column AS foreign_column,
            CASE
                WHEN confidence_level = 2 THEN 'Strong implied relationship (exact match)'
                WHEN confidence_level = 3 THEN 'Strong implied relationship (_id pattern)'
                WHEN confidence_level = 4 THEN 'Likely implied relationship (name match)'
                ELSE 'Possible implied relationship'
            END AS relationship_type,
            confidence_level
        FROM potential_references
        WHERE confidence_level IS NOT NULL
        ORDER BY confidence_level, source_column;
    """
    try:
        async with db_lifespan(mcp) as db_ctx:
            async with db_ctx.pool.acquire() as conn:
                fk_rows = await conn.fetch(fk_sql, db_schema, table_name)
                implied_rows = await conn.fetch(implied_sql, db_schema, table_name)
        if not fk_rows and not implied_rows:
            return "No relationships found for this table"
        return (
            f"Explicit Foreign Keys:\n{_format_records(fk_rows)}\n\n"
            f"Implied Relationships:\n{_format_records(implied_rows)}"
        )
    except Exception as e:
        return f"Error finding relationships: {str(e)}"


@mcp.tool(description="Visualizes query results using a prompt and JSON data.")
async def visualize_results(json_data: dict, vis_prompt: str) -> str:
    """
    Generates a visualization based on query results using PandasAI.

    Args:
        json_data (dict): A dictionary containing the query results with two keys:
            - 'columns': A list of column names (strings).
            - 'data': A list of lists, where each inner list represents a row
              of data; each element corresponds to a column in 'columns'.
            Example:
                {
                    'columns': ['Region', 'Product', 'Sales'],
                    'data': [
                        ['North', 'Widget', 150],
                        ['South', 'Widget', 200]
                    ]
                }
        vis_prompt (str): A natural language prompt describing the desired
            visualization (e.g., "Create a bar chart showing sales by region").

    Returns:
        str: The path to the saved visualization file or an error message.
            FIX(review): the original fell off the end and returned None when
            no chart file was produced; that case is now an explicit error.
    """
    try:
        # Optionally route PandasAI through OpenAI when a key is configured.
        openai_api_key = os.environ.get("OPENAI_API_KEY")
        if openai_api_key:
            pai.config.set({"llm": OpenAI(api_token=openai_api_key)})

        # Build the PandasAI frame straight from the JSON payload.
        df_ai = pai.DataFrame(data=json_data["data"], columns=json_data["columns"])
        pai.api_key.set(os.environ["PANDAS_KEY"])

        # Generate the visualization; PandasAI writes temp_chart* files into
        # the configured exports directory as a side effect.
        df_ai.chat(vis_prompt)

        exports_path = os.environ["PANDAS_EXPORTS_PATH"]
        generated_files = [f for f in os.listdir(exports_path) if f.startswith("temp_chart")]
        if not generated_files:
            return "Visualization error: no chart file was generated."
        # NOTE(review): picks the first matching file — if multiple charts
        # accumulate, consider selecting the newest by mtime.
        visualization_path = os.path.join(exports_path, generated_files[0])
        return f"Visualization saved as {visualization_path}"
    except Exception as e:
        return f"Visualization error: {str(e)}"


if __name__ == "__main__":
    mcp.run()