|
import asyncio |
|
import os |
|
import json |
|
from typing import List, Dict, Any, Union |
|
from contextlib import AsyncExitStack |
|
from datetime import datetime |
|
import gradio as gr |
|
from gradio.components.chatbot import ChatMessage |
|
from mcp import ClientSession, StdioServerParameters |
|
from mcp.client.stdio import stdio_client |
|
from mcp.client.sse import sse_client |
|
from anthropic import Anthropic |
|
from anthropic._exceptions import OverloadedError |
|
from dotenv import load_dotenv |
|
|
|
|
|
# Load variables from a local .env file into the process environment
# (e.g. the ANTHROPIC_API_KEY the __main__ guard checks, and HF_TOKEN used by
# MCPClientWrapper.connect for SSE servers).
load_dotenv()


# System prompt sent with every Anthropic request.
# NOTE: this is an f-string evaluated once at import time, so "Today is ..."
# keeps the process start date for as long as the app runs.
SYSTEM_PROMPT = f"""You are a helpful assistant. Today is {datetime.now().strftime("%Y-%m-%d")}.

You **do not** have prior knowledge of the World Development Indicators (WDI) data. Instead, you must rely entirely on the tools available to you to answer the user's questions.

When responding you must always plan the steps and enumerate all the tools that you plan to use to answer the user's query.

### Your Instructions:

1. **Tool Use Only**:
- You must not provide any answers based on prior knowledge or assumptions.
- You must **not** fabricate data or simulate the behavior of the `get_wdi_data` tool.
- You cannot use the `get_wdi_data` tool without using the `search_relevant_indicators` tool first.
- If the user requests WDI data, you **MUST ALWAYS** first call the `search_relevant_indicators` tool to see if there's any relevant data.
- If relevant data exists, call the `get_wdi_data` tool to get the data.

2. **Tool Invocation**:
- Use any relevant tools provided to you to answer the user's question.
- You may call multiple tools if needed, and you should do so in a logical sequence to minimize unnecessary user interaction.
- Do not hesitate to invoke tools as soon as they are relevant.

3. **Limitations**:
- If a user request cannot be fulfilled using the tools available, respond by clearly stating that you do not have access to that information.

4. **Ethical Guidelines**:
- Do not make or endorse statements based on stereotypes, bias, or assumptions.
- Ensure all claims and explanations are grounded in the data or factual evidence retrieved via tools.
- Politely refuse to respond to requests that involve stereotypes or unfounded generalizations.

5. **Communication Style**:
- Present the data in clear, user-friendly language.
- You may summarize or explain the data retrieved, but do **not** elaborate based on outside or implicit knowledge.
- You may describe the data in a way that is easy to understand but you MUST NOT elaborate based on external knowledge.

Stay strictly within these boundaries while maintaining a helpful and respectful tone."""


# Anthropic model id passed as `model=` on every Messages API request.
LLM_MODEL = "claude-3-5-haiku-20241022"


# Create a fresh event loop and install it as this thread's current loop at
# import time.
# NOTE(review): `loop` is not referenced anywhere else in the visible code;
# presumably this guarantees a current event loop exists before Gradio/MCP
# code runs. Confirm before removing.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
|
|
|
|
|
class MCPClientWrapper:
    """Connect a Gradio chat UI to an MCP server and drive Claude's tool-use loop.

    One instance holds a single MCP ``ClientSession`` (stdio transport for
    ``.py`` server scripts, SSE otherwise), the server's tools converted to the
    Anthropic tool schema, and an ``Anthropic`` client used for every request.
    """

    # Assistant reply shown when the Anthropic API raises OverloadedError.
    OVERLOADED_MESSAGE = "The LLM API is overloaded now, try again later..."

    # Maximum number of follow-up LLM requests (sent after tool results) per query.
    MAX_CALLS = 10

    def __init__(self):
        self.session = None  # ClientSession, set only after connect() succeeds
        self.exit_stack = None  # AsyncExitStack owning the transport and session
        self.anthropic = Anthropic()
        # Tool specs in Anthropic's {name, description, input_schema} form.
        self.tools = []

    async def connect(self, server_path_or_url: str) -> str:
        """Connect (or reconnect) to an MCP server and load its tool list.

        Args:
            server_path_or_url: A path ending in ``.py`` is launched with
                ``python`` as a local stdio server; anything else is treated
                as an SSE endpoint URL.

        Returns:
            A status string naming the tools the server exposes.
        """
        # Clear the session first so a failure anywhere below never leaves a
        # reference to a closed or uninitialized session behind.
        self.session = None
        if self.exit_stack:
            await self.exit_stack.aclose()
        self.exit_stack = AsyncExitStack()

        if server_path_or_url.endswith(".py"):
            server_params = StdioServerParameters(
                command="python",
                args=[server_path_or_url],
                env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"},
            )
            transport = await self.exit_stack.enter_async_context(
                stdio_client(server_params)
            )
        else:
            # Only authenticate when a token is configured; otherwise the
            # server would receive the literal header "Bearer None".
            hf_token = os.getenv("HF_TOKEN")
            headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else None
            transport = await self.exit_stack.enter_async_context(
                sse_client(server_path_or_url, headers=headers)
            )
        self.stdio, self.write = transport

        session = await self.exit_stack.enter_async_context(
            ClientSession(self.stdio, self.write)
        )
        await session.initialize()
        response = await session.list_tools()

        self.tools = [
            {
                "name": tool.name,
                "description": tool.description,
                "input_schema": tool.inputSchema,
            }
            for tool in response.tools
        ]
        # Publish the session only once it is initialized and tools are known.
        self.session = session

        print("Available tools:", self.tools)
        tool_names = [tool["name"] for tool in self.tools]
        return f"Connected to MCP server. Available tools: {', '.join(tool_names)}"

    async def process_message(
        self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
    ):
        """Gradio submit handler: stream the chat history while a query is answered.

        Yields ``(messages, gr.Textbox(value=""))`` pairs: the updated chat
        history for the Chatbot and an empty value that clears the input box.
        When connected, the finished conversation is appended as one JSON line
        to ``messages.log.jsonl``.
        """
        if not self.session:
            messages = history + [
                {"role": "user", "content": message},
                {
                    "role": "assistant",
                    "content": "Please connect to an MCP server first.",
                },
            ]
            yield messages, gr.Textbox(value="")
            return

        messages = history + [{"role": "user", "content": message}]
        yield messages, gr.Textbox(value="")

        # _process_query ends on its own after reporting an overloaded API.
        async for partial in self._process_query(message, history):
            messages.extend(partial)
            yield messages, gr.Textbox(value="")

        # One JSON record per line (JSONL). default=str keeps this best-effort
        # log from failing on non-JSON history entries such as ChatMessage.
        record = {"time": f"{datetime.now()}", "messages": messages}
        with open("messages.log.jsonl", "a", encoding="utf-8") as log_file:
            log_file.write(json.dumps(record, default=str) + "\n")

    @staticmethod
    def _to_claude_messages(
        history: List[Union[Dict[str, Any], ChatMessage]], message: str
    ) -> List[Dict[str, Any]]:
        """Convert Gradio chat history plus the new user message to API messages.

        Only ``user``/``assistant`` turns with non-empty string content are
        kept: the system prompt is passed separately via ``system=``, and the
        Messages API rejects a ``system`` role or empty content in ``messages``.
        """
        claude_messages = []
        for msg in history:
            if isinstance(msg, ChatMessage):
                role, content = msg.role, msg.content
            else:
                role, content = msg.get("role"), msg.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                claude_messages.append({"role": role, "content": content})
        claude_messages.append({"role": "user", "content": message})
        return claude_messages

    async def _create_message(self, claude_messages: List[Dict[str, Any]]):
        """Request Claude's next turn; return None if the API is overloaded.

        The ``Anthropic`` client is synchronous, so the request runs in a
        worker thread instead of blocking the event loop that serves Gradio.
        """
        try:
            return await asyncio.to_thread(
                self.anthropic.messages.create,
                model=LLM_MODEL,
                system=SYSTEM_PROMPT,
                max_tokens=1000,
                messages=claude_messages,
                tools=self.tools,
            )
        except OverloadedError:
            return None

    @staticmethod
    def _parse_tool_result(result) -> Any:
        """Turn an MCP tool result's ``content`` into JSON-serializable data.

        List items are dumped to dicts with ``annotations`` removed; when an
        item's ``text`` is valid JSON it is decoded in place so both the UI and
        Claude see structured data instead of an escaped string.
        """
        result_content = result.content
        if isinstance(result_content, list):
            result_content = [item.model_dump() for item in result_content]
            for item in result_content:
                item.pop("annotations", None)
                try:
                    item["text"] = json.loads(item["text"])
                except (KeyError, TypeError, ValueError):
                    # Items without text, or text that isn't JSON, stay as-is.
                    pass
        return result_content

    async def _process_query(
        self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
    ):
        """Run Claude's tool-use loop for one user query.

        Yields lists of Gradio message dicts to append to the chat: assistant
        text and, per tool call, a "pending" call message, its parameters, a
        result header, and the raw output. Every tool requested in one Claude
        response is executed and the results are returned together in a single
        user turn. At most ``MAX_CALLS`` follow-up requests are made; if the
        API is overloaded, ``OVERLOADED_MESSAGE`` is yielded and the query ends.
        """
        claude_messages = self._to_claude_messages(history, message)

        response = await self._create_message(claude_messages)
        if response is None:
            yield [{"role": "assistant", "content": self.OVERLOADED_MESSAGE}]
            return
        print(response.content)

        auto_calls = 0
        while True:
            # Once the follow-up budget is spent, tools requested in this
            # response are not executed, but its text is still shown.
            can_call_tools = auto_calls < self.MAX_CALLS
            assistant_blocks = []  # Claude's turn, echoed back on the next request
            tool_results = []  # one tool_result per executed tool_use

            for content in response.content:
                if content.type == "text":
                    # Empty text blocks are not accepted in earlier turns.
                    if content.text:
                        assistant_blocks.append({"type": "text", "text": content.text})
                    yield [{"role": "assistant", "content": content.text}]

                elif content.type == "tool_use" and can_call_tools:
                    tool_id = content.id
                    tool_name = content.name
                    tool_args = content.input
                    assistant_blocks.append(
                        {
                            "type": "tool_use",
                            "id": tool_id,
                            "name": tool_name,
                            "input": tool_args,
                        }
                    )

                    # Hide the server name inside tool names in UI titles.
                    display_name = tool_name.replace(
                        "avsolatorio_test_data_mcp_server", ""
                    )
                    # Keyed by the unique tool_use id so repeated or parallel
                    # calls to the same tool don't share metadata ids.
                    call_id = f"tool_call_{tool_id}"
                    result_id = f"result_{tool_id}"

                    call_message = {
                        "role": "assistant",
                        "content": f"I'll use the {tool_name} tool to help answer your question.",
                        "metadata": {
                            "title": f"Using tool: {display_name}",
                            "log": f"Parameters: {json.dumps(tool_args, ensure_ascii=True)}",
                            "status": "pending",
                            "id": call_id,
                        },
                    }
                    yield [call_message]
                    yield [
                        {
                            "role": "assistant",
                            "content": "```json\n"
                            + json.dumps(tool_args, indent=2, ensure_ascii=True)
                            + "\n```",
                            "metadata": {
                                "parent_id": call_id,
                                "id": f"params_{tool_id}",
                                "title": "Tool Parameters",
                            },
                        }
                    ]

                    print(f"Calling tool: {tool_name} with args: {tool_args}")
                    result = await self.session.call_tool(tool_name, tool_args)
                    # call_message is the same dict process_message appended to
                    # the chat history, so the next yield shows it as "done".
                    call_message["metadata"]["status"] = "done"

                    yield [
                        {
                            "role": "assistant",
                            "content": "Here are the results from the tool:",
                            "metadata": {
                                "title": f"Tool Result for {display_name}",
                                "status": "done",
                                "id": result_id,
                            },
                        }
                    ]

                    print(result.content)
                    result_content = self._parse_tool_result(result)
                    print("result_content", result_content)
                    result_json = json.dumps(result_content, indent=2)

                    yield [
                        {
                            "role": "assistant",
                            "content": "```\n" + result_json + "\n```",
                            "metadata": {
                                "parent_id": result_id,
                                "id": f"raw_result_{tool_id}",
                                "title": "Raw Output",
                            },
                        }
                    ]

                    tool_result = {
                        "type": "tool_result",
                        "tool_use_id": tool_id,
                        "content": result_json,
                    }
                    # Tell Claude when the MCP server flagged the call as failed.
                    if getattr(result, "isError", False):
                        tool_result["is_error"] = True
                    tool_results.append(tool_result)

            if not tool_results:
                break

            # Echo Claude's whole turn, then answer all of its tool calls in a
            # single user turn, as the Messages API expects.
            claude_messages.append({"role": "assistant", "content": assistant_blocks})
            claude_messages.append({"role": "user", "content": tool_results})

            response = await self._create_message(claude_messages)
            if response is None:
                yield [{"role": "assistant", "content": self.OVERLOADED_MESSAGE}]
                return
            auto_calls += 1
            print("next_response", response.content)
|
|
|
|
|
def gradio_interface(
    server_path_or_url: str = "https://avsolatorio-test-data-mcp-server.hf.space/gradio_api/mcp/sse",
):
    """Build the Gradio Blocks app for chatting with the WDI MCP assistant.

    Args:
        server_path_or_url: Initial value of the server-path textbox, which is
            passed to ``MCPClientWrapper.connect``. A value ending in ``.py``
            also makes the connection panel visible.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.

    Note:
        A single ``MCPClientWrapper`` is shared by every browser session, and
        ``demo.load`` runs its ``connect`` when the page loads; ``connect``
        closes any previous connection of that shared client first.
    """
    mcp_client = MCPClientWrapper()
    is_local_script = server_path_or_url.endswith(".py")

    def clear_history():
        # Reset the chatbot to an empty message list.
        return []

    with gr.Blocks(title="WDI MCP Client") as demo:
        gr.Markdown("## Ask about the World Development Indicators (WDI) data")

        # Connection controls are only shown when serving a local script.
        with gr.Accordion(
            "Connect to the WDI MCP server and chat with the assistant",
            open=False,
            visible=is_local_script,
        ):
            with gr.Row(equal_height=True):
                with gr.Column(scale=4):
                    server_path = gr.Textbox(
                        label="Server Script Path",
                        placeholder="Enter path to server script (e.g., wdi_mcp_server.py)",
                        value=server_path_or_url,
                    )
                with gr.Column(scale=1):
                    connect_btn = gr.Button("Connect")

            status = gr.Textbox(label="Connection Status", interactive=False)

        chatbot = gr.Chatbot(
            value=[],
            height=600,
            type="messages",
            show_copy_button=True,
            avatar_images=("img/small-user.png", "img/small-robot.png"),
            autoscroll=True,
        )

        with gr.Row(equal_height=True):
            question_box = gr.Textbox(
                label="Your Question",
                placeholder="Ask about what indicators are available for a specific topic (e.g., What's the definition of GDP?)",
                scale=4,
            )
            clear_btn = gr.Button("Clear Chat", scale=1)

        # Event wiring: explicit reconnect, connect on page load, chat submit,
        # and clearing the conversation.
        connect_btn.click(fn=mcp_client.connect, inputs=server_path, outputs=status)
        demo.load(fn=mcp_client.connect, inputs=server_path, outputs=status)
        question_box.submit(
            fn=mcp_client.process_message,
            inputs=[question_box, chatbot],
            outputs=[chatbot, question_box],
        )
        clear_btn.click(fn=clear_history, inputs=None, outputs=chatbot)

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Only warn about a missing key; the app is still built and launched.
    missing_api_key = not os.getenv("ANTHROPIC_API_KEY")
    if missing_api_key:
        print(
            "Warning: ANTHROPIC_API_KEY not found in environment. Please set it in your .env file."
        )

    app = gradio_interface()
    # SERVER_NAME lets deployments bind to another interface (default: localhost).
    app.launch(debug=True, server_name=os.getenv("SERVER_NAME", "127.0.0.1"))
|
|