import os
import dspy
import mlflow
import asyncio
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
# Configure the language model driving the ReAct agent.
# temperature=0 keeps SQL generation deterministic across runs.
lm = dspy.LM(
    model='openai/gpt-4o-mini',
    temperature=0,
    api_key=os.environ['OPENAI_API_KEY'],      # raises KeyError if unset
    api_base=os.environ['OPENAI_BASE_URL']
)

# Streamable-HTTP endpoint of the credit-card database MCP server.
mcp_url = "https://pgurazada1-credit-card-database-mcp-server.hf.space/mcp/"

# IMPORTANT: Set your Hugging Face user access token in the environment variable HF_TOKEN
# Fix: the original read HUGGINGFACE_API_KEY even though the instructions and
# the error message below both say HF_TOKEN. Read HF_TOKEN first and keep
# HUGGINGFACE_API_KEY as a backward-compatible fallback.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_API_KEY")
if not HF_TOKEN:
    raise RuntimeError("Please set your Hugging Face user access token in the HF_TOKEN environment variable.")

# Register the LM globally so every DSPy module uses it.
dspy.configure(lm=lm)

# Trace all DSPy calls to MLflow under a dedicated experiment.
mlflow.dspy.autolog()
mlflow.set_experiment('sql-react-agent-http')
# DSPy signature for the SQL agent. NOTE: in DSPy the docstring below is the
# system instruction sent to the LM verbatim — treat it as runtime behavior,
# not documentation, and do not edit it casually.
class QueryResponse(dspy.Signature):
    """
    You are an expert AI assistant specialized in generating and executing SQLite queries against a database.
    Your primary goal is to accurately answer user questions based *only* on the data retrieved. You must be methodical in exploring the database structure.
    1. **List All Tables:** Always start with `sql_db_list_tables`.
    2. **Identify Potential Tables:** List tables potentially holding the requested entities (e.g., cities, merchants) and metrics (e.g., spend). Also, identify tables that might *link* these entities (often containing ID columns like `cust_id`, `CARD_ID`, `M_ID`).
    3. **Get Schemas Systematically:** Use `sql_db_schema` to get schemas for *all* tables identified in step 2. This is crucial. Do not skip potential linking tables.
    4. **Map the Join Path:**
    * Explicitly identify the column containing the primary metric (e.g., `transaction.TX_AMOUNT`).
    * Explicitly identify the column containing the target entity (e.g., `customer.city`).
    * **CRITICAL:** Trace the connections between these tables using ID columns revealed in the schemas. Look for sequences like `tableA.ID -> tableB.tableA_ID`, `tableB.ID -> tableC.tableB_ID`.
    * **Example Path:** To link transaction spend to customer city, you MUST verify the path: `transaction.CARD_ID` links to `card.card_number`, AND `card.cust_id` links to `customer.cust_id`. You **MUST** request the schema for the `card` table to confirm this.
    * **State the Path:** Before writing the query, state the full join path you intend to use (e.g., "Found path: transaction JOIN card ON transaction.CARD_ID = card.card_number JOIN customer ON card.cust_id = customer.cust_id").
    5. **Verify Columns:** Double-check that *every* column used in your intended SELECT, JOIN, WHERE, GROUP BY, or ORDER BY clauses exists in the schemas you retrieved.
    6. **Construct Query:** Build the SQLite query using the verified tables, columns, and the full, correct join path.
    * Use explicit JOIN clauses (INNER JOIN is usually appropriate unless otherwise specified).
    * Quote identifiers (like `"transaction"`) if they are keywords or contain special characters.
    * Select only necessary columns. Alias columns for clarity if needed (e.g., `SUM(t.TX_AMOUNT) AS total_spend`).
    * Include calculations like percentage contribution if requested. The total sum for percentage calculation should be derived correctly (e.g., `(SELECT SUM(TX_AMOUNT) FROM "transaction")`).
    * Apply `GROUP BY` to the target entity column (e.g., `c.city`).
    * Apply `ORDER BY` and `LIMIT 5` (unless otherwise specified).
    7. **Validate Query:** Use `sql_db_query_checker`. Revise if syntax errors occur.
    8. **Execute Query:** Use `sql_db_query`.
    9. **Formulate Answer:** Base the final answer *strictly* on the query results. If the query returns no results *after* confirming a valid join path and correct syntax, state that no data matching the criteria was found.
    10. **Handle Missing Information:** If, after thorough schema exploration (including checking potential linking tables), you cannot find the requested column (e.g., 'country') or a valid join path, *then and only then* inform the user the data is unavailable. Do not substitute unrelated columns.
    11. **Final Answer Only:** Provide the answer directly without further tool calls once results are obtained.
    1. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.).
    2. DO NOT MAKE UP ANSWERS.
    """
    # Natural-language question from the user (input to the agent).
    query: str = dspy.InputField()
    # Agent's final answer, grounded in the executed query results.
    answer: str = dspy.OutputField(desc="The generated response to the customer query.")
async def respond(query):
    """Answer one user query with a ReAct agent backed by the MCP server's tools.

    Opens a fresh streamable-HTTP connection, performs the MCP handshake,
    wraps every tool the server advertises as a DSPy tool, then runs the
    ReAct loop until an answer is produced. Returns the agent's prediction.
    """
    async with streamablehttp_client(
        url=mcp_url,
        headers={"Authorization": f"Bearer {HF_TOKEN}"}
    ) as (read_stream, write_stream, _):
        async with ClientSession(read_stream, write_stream) as session:
            # Handshake must complete before any tool traffic.
            await session.initialize()
            # Discover the server's tools and expose each one to DSPy.
            listed = await session.list_tools()
            toolbox = [dspy.Tool.from_mcp_tool(session, tool) for tool in listed.tools]
            # Build the agent per call (the session it closes over is per-call too).
            agent = dspy.ReAct(QueryResponse, tools=toolbox, max_iters=10)
            return await agent.acall(query=query)
# Examples 1-3: run each query sequentially and print its answer.
for user_query in (
    "Who are the top 5 merchants by total number of transactions?",
    "Which is the highest spend month and amount for each card type?",
    "Which are the top 5 cities with the highest spend and what is their percentage contribution to overall spends?",
):
    pred = asyncio.run(respond(user_query))
    print(pred.answer)
# Parallelism: fan the queries out concurrently, one MCP session per query.
async def main():
    """Run the example queries concurrently and return their predictions in order."""
    user_queries = [
        "Who are the top 5 merchants by total transactions?",
        "Which is the highest spend month and amount for each card type?",
        "Which are the top 5 cities with the highest spend and what is their percentage contribution to overall spends?"
    ]
    # gather preserves input order regardless of completion order.
    return await asyncio.gather(*(respond(q) for q in user_queries))

results = asyncio.run(main())
for result in results:
    print(result.answer)