|
|
import asyncio
import hmac
import json
import os
import time
from typing import Dict, List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
|
|
|
|
|
|
|
|
# ASGI application serving the landing page, the /v1/mock API and the /ws WebSocket.
app = FastAPI()

# NOTE(review): wildcard origins together with allow_credentials=True means Starlette
# reflects any requesting Origin with credentials allowed. The API authenticates with a
# secret in the request body, so credentials are probably unnecessary — confirm before
# tightening this, since browser clients may rely on it.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

# Expose ./static (resolved against the current working directory) under /static.
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
|
|
class MockRequest(BaseModel):
    """JSON body expected by ``POST /v1/mock``."""

    # Text forwarded to the WebSocket client as the "prompt" field.
    parameter: str
    # Optional model name forwarded as the "model" field. Declared Optional so that an
    # explicit JSON null validates (a bare ``str = None`` only allows omission in
    # Pydantic v2).
    model: Optional[str] = None
    # Shared secret checked against the API_SECRET environment variable.
    secret: str
|
|
|
|
|
class ConnectionManager:
    """Track active WebSocket connections and the reply each one still owes.

    Replies are correlated per connection, keyed by ``str(id(websocket))``. Each
    connection has at most one pending future, which the next text message received
    on that connection resolves (see the ``/ws`` handler).
    """

    def __init__(self):
        self.active_connections: List[WebSocket] = []
        # str(id(websocket)) -> future awaiting that connection's next text message.
        self.response_futures: Dict[str, asyncio.Future] = {}

    async def connect(self, websocket: WebSocket):
        """Accept *websocket* and start tracking it."""
        await websocket.accept()
        self.active_connections.append(websocket)
        print(f"Nouvelle connexion WebSocket. Total: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket* and fail any request still waiting on it.

        Safe to call for a connection that is no longer tracked.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        pending = self.response_futures.pop(str(id(websocket)), None)
        if pending is not None and not pending.done():
            # Wake the waiting request immediately instead of leaving it until its timeout.
            pending.set_exception(ConnectionError("Client WebSocket déconnecté."))
        print(f"Déconnexion WebSocket. Total: {len(self.active_connections)}")

    async def broadcast(self, message: str):
        """Send *message* to the first connected client and return a future for its reply.

        Despite the name, only ``active_connections[0]`` receives the message, because
        callers wait for exactly one reply.

        Returns:
            An ``asyncio.Future`` resolved with the client's next text message, or
            ``None`` when no client is connected.

        Raises:
            Whatever ``websocket.send_text`` raises; the future is unregistered first.

        NOTE(review): a second call made before the first reply arrives replaces the
        pending future for the same connection, so a reply can be matched to the wrong
        request. Fixing that needs a request id in the message protocol.
        """
        if not self.active_connections:
            return None

        websocket = self.active_connections[0]
        client_id = str(id(websocket))
        future = asyncio.get_running_loop().create_future()
        # Register before sending so a reply can never arrive ahead of its future.
        self.response_futures[client_id] = future
        try:
            await websocket.send_text(message)
        except Exception:
            # The send failed: drop our future so it cannot capture a later, unrelated reply.
            if self.response_futures.get(client_id) is future:
                del self.response_futures[client_id]
            raise
        return future


# Single shared manager used by both the HTTP and the WebSocket endpoints.
manager = ConnectionManager()
|
|
|
|
|
def verify_secret(provided_secret: str) -> bool:
    """Return True if *provided_secret* matches the ``API_SECRET`` environment variable.

    When ``API_SECRET`` is unset or empty, a warning is printed and False is returned.
    A misconfigured deployment therefore rejects every request instead of accepting
    anything. Non-str input is rejected.

    The comparison uses ``hmac.compare_digest``, which avoids content-based
    short-circuiting. Response timing therefore does not reveal how much of the
    secret matched.
    """
    expected_secret = os.getenv("API_SECRET")

    if not expected_secret:
        print("ATTENTION: Variable d'environnement API_SECRET non définie!")
        return False

    if not isinstance(provided_secret, str):
        return False

    # compare_digest refuses non-ASCII str arguments, so compare the UTF-8 bytes instead.
    return hmac.compare_digest(
        provided_secret.encode("utf-8"),
        expected_secret.encode("utf-8"),
    )
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def root():
    """Return ``static/index.html`` as the landing page, or a 404 if it is missing.

    NOTE(review): the file is read synchronously inside the event loop; acceptable for
    a small page, but worth revisiting if the file grows.
    """
    index_path = "static/index.html"
    try:
        with open(index_path, "r", encoding="utf-8") as handle:
            page = handle.read()
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="index.html not found")
    return HTMLResponse(content=page)
|
|
|
|
|
@app.post("/v1/mock")
async def mock_endpoint(payload: MockRequest):
    """Relay a prompt to the connected WebSocket client and return its reply.

    Checks ``payload.secret`` with ``verify_secret``. Sends the JSON message
    ``{"prompt": payload.parameter, "model": payload.model}`` once through
    ``manager.broadcast``, then waits up to 60 seconds for the client's reply.

    Returns:
        ``{"response_from_client": <reply text>, "completion_time_in_seconds": <float>}``.
        The duration is measured with ``time.monotonic`` and rounded to 2 decimals.

    Raises:
        HTTPException:
            401 for an invalid secret;
            400 if ``parameter`` is None;
            503 when no WebSocket client is connected;
            408 when the client does not reply within 60 seconds;
            500 when the message could not be sent or any other error occurs.
    """
    start_time = time.monotonic()
    response_future = None

    try:
        input_string = payload.parameter
        selected_model = payload.model
        provided_secret = payload.secret

        if not verify_secret(provided_secret):
            # Never echo any part of a rejected secret: a near-miss attempt would
            # otherwise write most of the real secret into the logs.
            print("Tentative d'accès avec un secret invalide.")
            raise HTTPException(
                status_code=401,
                detail="Secret invalide. Accès non autorisé.",
            )

        print(f"Secret vérifié avec succès. Endpoint /v1/mock appelé avec: '{input_string}'")

        # Defensive only: MockRequest already declares ``parameter`` as a required str.
        if input_string is None:
            raise HTTPException(status_code=400, detail="Le paramètre 'parameter' est manquant.")

        if not manager.active_connections:
            raise HTTPException(status_code=503, detail="Aucun client WebSocket n'est connecté.")

        message_data = {
            "prompt": input_string,
            "model": selected_model,
        }

        # Send exactly once. A second send (e.g. of the raw string) makes the client answer
        # twice and replaces the pending future, so the reply/request pairing becomes racy.
        print("Envoi du message au client WebSocket...")
        response_future = await manager.broadcast(json.dumps(message_data))

        if response_future is None:
            raise HTTPException(status_code=500, detail="Échec de la diffusion du message.")

        try:
            websocket_response = await asyncio.wait_for(response_future, timeout=60.0)
        except asyncio.TimeoutError:
            # NOTE(review): 504 would describe an upstream timeout more accurately;
            # 408 is kept because existing callers may depend on it.
            print("Timeout: Aucune réponse du client WebSocket.")
            raise HTTPException(status_code=408, detail="Timeout: Le client n'a pas répondu à temps.")

        print(f"Réponse reçue du WebSocket: '{websocket_response}'")
        duration = time.monotonic() - start_time
        print(f"Requête complétée en {duration:.2f} secondes.")
        return {
            "response_from_client": websocket_response,
            "completion_time_in_seconds": round(duration, 2),
        }

    except HTTPException:
        raise
    except Exception as e:
        print(f"Erreur dans /v1/mock: {e}")
        raise HTTPException(status_code=500, detail=f"Une erreur interne est survenue: {str(e)}")
    finally:
        # After a timeout or error our (now cancelled) future may still be registered;
        # drop it so a late reply is not routed to a request that no longer waits.
        if response_future is not None:
            _forget_response_future(response_future)


def _forget_response_future(future: asyncio.Future) -> None:
    """Remove *future* from ``manager.response_futures`` if it is still registered there."""
    for client_id, pending in list(manager.response_futures.items()):
        # Identity check: never drop a newer future registered by another request.
        if pending is future:
            del manager.response_futures[client_id]
|
|
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Accept a client connection and route its text messages to waiting requests.

    Each received text message resolves the future that ``manager.broadcast``
    registered for this connection, if any. Messages arriving when no request is
    waiting are only logged. The connection is always unregistered when the loop ends.
    """
    await manager.connect(websocket)
    client_id = str(id(websocket))
    try:
        while True:
            data = await websocket.receive_text()
            print(f"Message reçu du client: '{data}'")

            pending = manager.response_futures.pop(client_id, None)
            # A request that already timed out leaves a cancelled future behind. Calling
            # set_result on it would raise InvalidStateError and end this connection loop.
            if pending is not None and not pending.done():
                pending.set_result(data)

    except WebSocketDisconnect:
        print("Client déconnecté.")
    except Exception as e:
        print(f"Erreur dans le WebSocket: {e}")
    finally:
        # Runs for every exit path (including cancellation) so no dead socket stays
        # in active_connections and gets picked by broadcast.
        manager.disconnect(websocket)
|
|
|
|
|
if __name__ == "__main__":
    # Run directly as a script: serve on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")