# MCPSynapseChat / app.py
# Omar ID EL MOUMEN
# Final version
# 8227e25
import asyncio
import json
import traceback
from contextlib import asynccontextmanager
from typing import List, Optional, Any, Dict

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from mcp_client import MCPClient
# Module-wide MCP client instance; the lifespan hook and the tool endpoints
# in this file all go through this single object.
mcp = MCPClient()
class ChatMessage(BaseModel):
    """A single chat message: a speaker label plus the message text."""

    # Speaker label (the commented-out chat endpoint in this file uses
    # "user" and "assistant").
    role: str
    # Text of the message.
    content: str
class ChatCompletionRequest(BaseModel):
    """Request body for a chat completion.

    NOTE(review): in the visible part of this file this model is referenced
    only by the commented-out ``/v1/chat/completions`` endpoint.
    """

    # Model identifier; this default is used when the caller omits the field.
    model: str = "gemini-2.5-pro-exp-03-25"
    # Required list of conversation messages.
    messages: List[ChatMessage]
    # Optional tool definitions; defaults to an empty list.
    tools: Optional[list] = []
    # Optional; None when not supplied. Presumably a generation-length
    # limit -- confirm against the consumer of this field.
    max_tokens: Optional[int] = None
class ChatCompletionResponseChoice(BaseModel):
    """One entry of the ``choices`` list in a chat completion response."""

    # Position of this choice in the response; defaults to 0.
    index: int = 0
    # The message carried by this choice.
    message: ChatMessage
    # Defaults to "stop".
    finish_reason: str = "stop"
class ChatCompletionResponse(BaseModel):
    """Response body for a chat completion."""

    # Identifier of the response (the commented-out endpoint reuses the
    # request id here).
    id: str
    # Fixed object-type tag.
    object: str = "chat.completion"
    # Integer creation time; presumably a Unix timestamp, judging by the
    # commented-out endpoint's ``int(time.time())`` -- confirm.
    created: int
    # Model identifier echoed back to the caller.
    model: str
    # The returned choices.
    choices: List[ChatCompletionResponseChoice]
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Open the MCP connection on startup and release it on shutdown.

    A failed connection attempt is only reported on stdout: the application
    still starts, and the tool endpoints retry the connection on demand.
    """
    # --- startup ---------------------------------------------------------
    try:
        await mcp.connect()
        print("Connexion au MCP réussi !")
    except Exception as startup_error:
        print("Warning ! : Connexion au MCP impossible\n", str(startup_error))

    yield

    # --- shutdown --------------------------------------------------------
    # Nothing to release when no session was ever established.
    if not mcp.session:
        return

    try:
        await mcp.exit_stack.aclose()
        print("MCP déconnecté !")
    except Exception as shutdown_error:
        print("Erreur à la fermeture du MCP\n", str(shutdown_error))
# Application instance; `lifespan` manages the MCP connection around the
# server's lifetime.
app = FastAPI(lifespan=lifespan)
# Serve files from the local "static" directory under the /static prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# NOTE(review): every origin, method and header is allowed while credentials
# are enabled. FastAPI's CORS documentation advises against combining
# wildcard values with allow_credentials=True -- confirm the intended
# origins and list them explicitly if this is deployed beyond local use.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"]
)
class ConnectionManager:
    """Track accepted WebSocket connections and route messages between them.

    Every connection carries a *source* label (``None`` until one is set via
    ``set_source``). ``send_to_dest`` delivers a message to each connection
    whose label equals the requested destination.
    """

    def __init__(self):
        # websocket -> source label (None until a label is set)
        self.active_connections = {}
        # request_id -> asyncio.Queue on which a pending waiter expects its reply
        self.response_queues = {}

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and register the connection unlabelled."""
        await websocket.accept()
        self.active_connections[websocket] = None

    def set_source(self, websocket: WebSocket, source: str):
        """Attach ``source`` as the label of a registered connection.

        Connections that are not (or no longer) registered are ignored.
        """
        if websocket in self.active_connections:
            self.active_connections[websocket] = source

    async def send_to_dest(self, destination: str, message: str):
        """Send ``message`` to every connection labelled ``destination``.

        Errors raised by an individual ``send_text`` call propagate to the
        caller.
        """
        # Work on a snapshot: each ``await`` below yields to the event loop,
        # and a connect()/remove() happening meanwhile would make iteration
        # over the live dict fail with
        # "RuntimeError: dictionary changed size during iteration".
        targets = [
            ws for ws, src in self.active_connections.items() if src == destination
        ]
        for ws in targets:
            # Skip sockets that were removed while an earlier send was pending.
            if ws in self.active_connections:
                await ws.send_text(message)

    def remove(self, websocket: WebSocket):
        """Forget a connection; a no-op when it is not registered."""
        self.active_connections.pop(websocket, None)

    async def wait_for_response(self, request_id: str, timeout: int = 30):
        """Wait for the reply associated with ``request_id``.

        A one-slot queue is registered under ``request_id`` in
        ``response_queues``; whoever produces the reply is expected to put it
        on that queue. The registration is always removed afterwards.

        Returns:
            The first item put on the queue.

        Raises:
            asyncio.TimeoutError: no item arrived within ``timeout`` seconds.
        """
        queue = asyncio.Queue(maxsize=1)
        self.response_queues[request_id] = queue
        try:
            return await asyncio.wait_for(queue.get(), timeout=timeout)
        finally:
            self.response_queues.pop(request_id, None)
# Single shared connection registry used by the /ws endpoint.
manager = ConnectionManager()
@app.get("/")
async def index_page():
    """Serve ``index.html`` at the site root.

    The path is relative, so the file is looked up from the process's
    current working directory.
    """
    page = FileResponse("index.html")
    return page
# @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
# async def chat_completions(request: ChatCompletionRequest):
# request_id = str(uuid.uuid4())
# proxy_ws = next((ws for ws, src in manager.active_connections.items() if src == "proxy"), None)
# if not proxy_ws:
# raise HTTPException(503, "Proxy client not connected !")
# user_msg = next((m for m in request.messages if m.role == "user"), None)
# if not user_msg:
# raise HTTPException(400, "No user message found !")
# proxy_msg = {
# "request_id": request_id,
# "content": user_msg.content,
# "source": "api",
# "destination": "proxy",
# "model": request.model,
# "tools": request.tools,
# "max_tokens": request.max_tokens
# }
# await proxy_ws.send_text(json.dumps(proxy_msg))
# try:
# response_content = await manager.wait_for_response(request_id)
# except asyncio.TimeoutError:
# raise HTTPException(504, "Proxy response timeout")
# return ChatCompletionResponse(
# id=request_id,
# created=int(time.time()),
# model=request.model,
# choices=[ChatCompletionResponseChoice(
# message=ChatMessage(role="assistant", content=response_content)
# )]
# )
class ToolCallRequest(BaseModel):
    """Request body of ``POST /call-tools``: the tool calls to execute."""

    # The /call-tools handler reads each entry as
    # entry["function"]["name"] and entry["function"]["arguments"].
    tool_calls: List[Dict[str, Any]]
@app.get("/list-tools", response_model=List[Dict[str, Any]])
async def list_tools():
    """Return the tools reported by the MCP client.

    Connects to the MCP first when no session is open yet.

    Raises:
        HTTPException: 503 when the MCP connection cannot be established,
            500 when fetching the tool list fails.
    """
    # Step 1: make sure a session exists (connect on demand).
    if not mcp.session:
        try:
            await mcp.connect()
        except Exception as e:
            unavailable = f"Connexion au MCP impossible !\n{str(e)}"
            raise HTTPException(status_code=503, detail=unavailable)

    # Step 2: fetch the tool list.
    try:
        available_tools = await mcp.list_tools()
    except Exception as e:
        failure = f"Erreur lors de la récupération des outils: {str(e)}"
        raise HTTPException(status_code=500, detail=failure)
    return available_tools
def _decode_tool_arguments(raw_args):
    """Return tool-call arguments in decoded form.

    Accepts a JSON-encoded string (the original wire format), an
    already-decoded object (returned unchanged), or a missing/empty value
    (treated as "no arguments", i.e. ``{}``).
    """
    if raw_args is None or raw_args == "":
        return {}
    if isinstance(raw_args, (str, bytes, bytearray)):
        return json.loads(raw_args)
    return raw_args


@app.post("/call-tools")
async def call_tools(request: ToolCallRequest):
    """Execute each requested tool call through the MCP session, in order.

    Connects to the MCP first when no session is open yet.

    Returns:
        One ``{"role": "user", "content": <text>}`` dict per tool call,
        where ``<text>`` is the text of the first content item of the tool
        result (empty string when the result has no content).

    Raises:
        HTTPException: 503 when the MCP connection cannot be established,
            500 when any tool call fails (including malformed entries).
    """
    if not mcp.session:
        try:
            await mcp.connect()
        except Exception as e:
            # The detail used to be the "list tools" message (copy/paste);
            # report the actual problem, as /list-tools does.
            raise HTTPException(
                status_code=503,
                detail=f"Connexion au MCP impossible !\n{str(e)}",
            ) from e
    try:
        result_tools = []
        for tool_call in request.tool_calls:
            print(tool_call)
            tool = tool_call["function"]
            tool_name = tool["name"]
            tool_args = _decode_tool_arguments(tool.get("arguments"))
            result = await mcp.session.call_tool(tool_name, tool_args)
            # A tool may return no content at all; do not index blindly.
            content = result.content
            result_tools.append({
                "role": "user",
                "content": content[0].text if content else "",
            })
        print("Finished !")
        return result_tools
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Erreur lors de l'appel des outils: {str(e)}",
        ) from e
async def _close_quietly(websocket: WebSocket):
    """Close ``websocket``, tolerating a connection that is already closed."""
    try:
        await websocket.close()
    except (RuntimeError, WebSocketDisconnect):
        # Sending a close frame on a socket the peer has already closed is
        # rejected; there is nothing left to do in that case.
        pass


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Relay JSON messages between WebSocket clients.

    Protocol, as implemented here:
      1. The first message is JSON; when it has a ``"source"`` key, its
         value becomes this connection's label.
      2. Every later message is JSON with a ``"destination"`` key and is
         forwarded verbatim to all connections labelled with that value.

    Any error (disconnect, invalid JSON, missing key, failed forward) ends
    the session: the connection is unregistered and closed.
    """
    await manager.connect(websocket)
    try:
        data = await websocket.receive_text()
        init_msg = json.loads(data)
        if 'source' in init_msg:
            manager.set_source(websocket, init_msg['source'])
            print(init_msg['source'])
        while True:
            message = await websocket.receive_text()
            msg_data = json.loads(message)
            await manager.send_to_dest(msg_data["destination"], message)
    except WebSocketDisconnect:
        # A peer went away: normal end of the session, nothing to report.
        pass
    except Exception as e:
        # Previously swallowed silently; surface the reason on stdout.
        print("Erreur WebSocket\n", str(e))
    finally:
        # Always unregister, including on task cancellation, so that no
        # stale socket stays routable.
        manager.remove(websocket)
    # The socket may already be closed (client disconnect), so closing must
    # not raise out of the handler.
    await _close_quietly(websocket)