Spaces:
Sleeping
Sleeping
import asyncio
import json
import traceback
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from mcp_client import MCPClient
# Single process-wide MCP client. The lifespan handler and the tool endpoints
# all connect through this one instance.
mcp: MCPClient = MCPClient()
class ChatMessage(BaseModel):
    # One chat turn: the speaker's role and the message text.
    # (Comments rather than a docstring: pydantic would copy a docstring into
    # the generated OpenAPI schema.)
    role: str
    content: str
class ChatCompletionRequest(BaseModel):
    # Request body of the (currently commented-out) /v1/chat/completions
    # handler. Field order is kept as-is: it drives schema/serialization order.
    model: str = "gemini-2.5-pro-exp-03-25"
    messages: List[ChatMessage]
    # pydantic copies non-hashable defaults per instance, so this [] literal
    # is not shared between requests.
    tools: Optional[list] = []
    max_tokens: Optional[int] = None
class ChatCompletionResponseChoice(BaseModel):
    # A single completion choice; the defaults describe the only choice of a
    # normally finished reply.
    index: int = 0
    message: ChatMessage
    finish_reason: str = "stop"
class ChatCompletionResponse(BaseModel):
    # Response envelope for a chat completion. The commented-out handler fills
    # `id` with a uuid4 string and `created` with int(time.time()).
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionResponseChoice]
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Open the shared MCP connection at start-up and release it at shutdown.

    A start-up connection failure is only reported, so the app still starts;
    the tool endpoints call ``mcp.connect()`` again on demand when
    ``mcp.session`` is falsy. At shutdown, the client's exit stack is closed
    if ``mcp.session`` is set; a failure while closing is reported, not raised.
    """
    try:
        await mcp.connect()
        print("Connexion au MCP réussi !")
    except Exception as e:
        print("Warning ! : Connexion au MCP impossible\n", str(e))
    try:
        yield
    finally:
        # `finally` so MCP resources are also released when the server stops
        # because an error was raised into the lifespan context.
        if mcp.session:
            try:
                await mcp.exit_stack.aclose()
                print("MCP déconnecté !")
            except Exception as e:
                print("Erreur à la fermeture du MCP\n", str(e))
app = FastAPI(lifespan=lifespan)

# Front-end assets from the ./static directory, served under /static.
app.mount("/static", StaticFiles(directory="static"), name="static")

# NOTE(review): the FastAPI/Starlette CORS docs advise against combining
# wildcard origins/methods/headers with allow_credentials=True; as configured,
# any site may issue credentialed cross-origin requests. Restrict
# allow_origins if cookies or auth headers are ever relied upon.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
class ConnectionManager:
    """Registry of open WebSocket connections, plus per-request response queues.

    ``active_connections`` maps each accepted socket to the ``source`` label
    its client announced (``None`` until :meth:`set_source` is called).
    ``response_queues`` maps a request id to the one-slot queue that
    :meth:`wait_for_response` is currently waiting on.
    """

    def __init__(self):
        self.active_connections = {}
        self.response_queues = {}

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and register the socket with no source."""
        await websocket.accept()
        self.active_connections[websocket] = None

    def set_source(self, websocket: WebSocket, source: str):
        """Label a registered socket with ``source``; unregistered sockets are ignored."""
        if websocket in self.active_connections:
            self.active_connections[websocket] = source

    async def send_to_dest(self, destination: str, message: str):
        """Send ``message`` to every registered socket whose source is ``destination``.

        Matching sockets are collected into a snapshot first. ``send_text``
        yields to the event loop, and another handler may connect or remove a
        socket meanwhile; iterating the live dict would then raise
        ``RuntimeError: dictionary changed size during iteration``.

        A recipient whose send fails is dropped from the registry instead of
        propagating the error. Otherwise a single dead peer would abort delivery
        to the remaining recipients and tear down the *sender's* connection.
        """
        recipients = [
            ws for ws, src in self.active_connections.items() if src == destination
        ]
        for ws in recipients:
            try:
                await ws.send_text(message)
            except Exception:
                self.remove(ws)

    def remove(self, websocket: WebSocket):
        """Unregister ``websocket``; a no-op if it is not registered."""
        self.active_connections.pop(websocket, None)

    async def wait_for_response(self, request_id: str, timeout: float = 30):
        """Wait up to ``timeout`` seconds for a value queued under ``request_id``.

        The one-slot queue is registered in ``response_queues`` only for the
        duration of the wait and is always unregistered, including on timeout.

        Raises:
            asyncio.TimeoutError: if nothing arrives within ``timeout`` seconds.
        """
        # NOTE(review): no visible live code puts into response_queues; the
        # only caller is the commented-out chat-completions handler. Confirm a
        # producer exists before relying on this.
        queue = asyncio.Queue(maxsize=1)
        self.response_queues[request_id] = queue
        try:
            return await asyncio.wait_for(queue.get(), timeout=timeout)
        finally:
            self.response_queues.pop(request_id, None)
# Process-wide connection registry used by the WebSocket relay endpoint.
manager: ConnectionManager = ConnectionManager()
async def index_page():
    """Serve the single-page front end from ``index.html`` in the working directory."""
    # NOTE(review): no route decorator is visible on this handler in this copy;
    # confirm it is registered (presumably with @app.get("/")).
    index_file = "index.html"
    return FileResponse(path=index_file)
# @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) | |
# async def chat_completions(request: ChatCompletionRequest): | |
# request_id = str(uuid.uuid4()) | |
# proxy_ws = next((ws for ws, src in manager.active_connections.items() if src == "proxy"), None) | |
# if not proxy_ws: | |
# raise HTTPException(503, "Proxy client not connected !") | |
# user_msg = next((m for m in request.messages if m.role == "user"), None) | |
# if not user_msg: | |
# raise HTTPException(400, "No user message found !") | |
# proxy_msg = { | |
# "request_id": request_id, | |
# "content": user_msg.content, | |
# "source": "api", | |
# "destination": "proxy", | |
# "model": request.model, | |
# "tools": request.tools, | |
# "max_tokens": request.max_tokens | |
# } | |
# await proxy_ws.send_text(json.dumps(proxy_msg)) | |
# try: | |
# response_content = await manager.wait_for_response(request_id) | |
# except asyncio.TimeoutError: | |
# raise HTTPException(504, "Proxy response timeout") | |
# return ChatCompletionResponse( | |
# id=request_id, | |
# created=int(time.time()), | |
# model=request.model, | |
# choices=[ChatCompletionResponseChoice( | |
# message=ChatMessage(role="assistant", content=response_content) | |
# )] | |
# ) | |
class ToolCallRequest(BaseModel):
    # Body of the tool-call endpoint: raw tool-call dicts, each read by
    # call_tools as {"function": {"name": ..., "arguments": ...}}.
    tool_calls: List[Dict[str, Any]]
async def list_tools():
    """Return the tool list reported by the MCP server.

    Calls ``mcp.connect()`` first when ``mcp.session`` is falsy, so the
    endpoint can recover after the start-up connection attempt failed.

    Raises:
        HTTPException: 503 if connecting to the MCP server fails; 500 if
            listing the tools fails. The original error is chained as the cause.
    """
    # NOTE(review): no route decorator is visible on this handler in this copy;
    # confirm it is registered.
    if not mcp.session:
        try:
            await mcp.connect()
        except Exception as e:
            raise HTTPException(
                status_code=503,
                detail=f"Connexion au MCP impossible !\n{str(e)}",
            ) from e
    try:
        tools = await mcp.list_tools()
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Erreur lors de la récupération des outils: {str(e)}",
        ) from e
    return tools
def _parse_tool_arguments(raw_args):
    """Convert a tool call's ``arguments`` field into the dict passed to MCP.

    Accepts a JSON-encoded str/bytes (the original expected form), an
    already-decoded object, or a missing/blank value, treated as no arguments.
    """
    if raw_args is None:
        return {}
    if isinstance(raw_args, (str, bytes, bytearray)):
        # Blank strings show up for tools that take no parameters.
        return json.loads(raw_args) if raw_args.strip() else {}
    return raw_args


async def call_tools(request: ToolCallRequest):
    """Run each requested tool call on the MCP server, in request order.

    Each entry of ``request.tool_calls`` is read as
    ``{"function": {"name": <str>, "arguments": <JSON str | dict | absent>}}``.

    Returns:
        One ``{"role": "user", "content": <text>}`` dict per tool call, where
        <text> is the ``text`` of the first content item of the MCP result.

    Raises:
        HTTPException: 503 if connecting to the MCP server fails; 500 if any
            tool call fails or is malformed. The original error is chained.
    """
    # NOTE(review): no route decorator is visible on this handler in this copy;
    # confirm it is registered.
    if not mcp.session:
        try:
            await mcp.connect()
        except Exception as e:
            # This is a connection failure, so report it with the same message
            # list_tools uses (it previously reused the tool-listing message).
            raise HTTPException(
                status_code=503,
                detail=f"Connexion au MCP impossible !\n{str(e)}",
            ) from e
    try:
        result_tools = []
        for tool_call in request.tool_calls:
            print(tool_call)
            tool = tool_call["function"]
            tool_name = tool["name"]
            tool_args = _parse_tool_arguments(tool.get("arguments"))
            result = await mcp.session.call_tool(tool_name, tool_args)
            result_tools.append({
                "role": "user",
                "content": result.content[0].text,
            })
        print("Finished !")
        return result_tools
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Erreur lors de l'appel des outils: {str(e)}",
        ) from e
async def websocket_endpoint(websocket: WebSocket):
    """Relay JSON text frames between WebSocket clients.

    The first frame must be JSON; if it contains a ``source`` key, the
    connection is registered under that label. Every later frame must be a JSON
    object with a ``destination`` key. Its original text is forwarded unchanged
    to every connection registered under that label.

    The connection is unregistered and closed when the client disconnects, when
    a frame is malformed, or when the handler is cancelled.
    """
    # NOTE(review): no @app.websocket(...) decorator is visible in this copy;
    # confirm the endpoint is registered.
    await manager.connect(websocket)
    try:
        init_msg = json.loads(await websocket.receive_text())
        if "source" in init_msg:
            manager.set_source(websocket, init_msg["source"])
            print(init_msg["source"])
        while True:
            message = await websocket.receive_text()
            msg_data = json.loads(message)
            await manager.send_to_dest(msg_data["destination"], message)
    except WebSocketDisconnect:
        # Normal client hang-up; nothing to report.
        pass
    except Exception:
        # Malformed frame (bad JSON, missing "destination") or relay failure:
        # report it rather than swallowing it, then close the connection below.
        traceback.print_exc()
    finally:
        # In `finally` so the registry is also cleaned up on cancellation
        # (CancelledError is not an Exception subclass). Unregister first, so
        # this step cannot be skipped if close() fails.
        manager.remove(websocket)
        try:
            await websocket.close()
        except Exception:
            # Closing a socket the client already disconnected can raise;
            # closing here is best-effort.
            pass